aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/AVX512/PacketMath.h
diff options
context:
space:
mode:
authorGravatar Antonio Sanchez <cantonios@google.com>2020-11-24 16:28:07 -0800
committerGravatar Antonio Sánchez <cantonios@google.com>2020-11-30 16:28:57 +0000
commit89f90b585d24b3c07946b4ffd8064e66ad5af94a (patch)
treec29344e3c03752faaaf2f8eee847811091688262 /Eigen/src/Core/arch/AVX512/PacketMath.h
parentc5985c46f5de0a7a381262c5a8a973806db92f40 (diff)
AVX512 missing ops.
This allows the `packetmath` tests to pass for AVX512 on skylake. Made `half` and `bfloat16` consistent in terms of ops they support. Note the `log` tests are currently disabled for `bfloat16` since they fail due to poor precision (they were previously disabled for `Packet8bf` via test function specialization -- I just removed that specialization and disabled it in the generic test).
Diffstat (limited to 'Eigen/src/Core/arch/AVX512/PacketMath.h')
-rw-r--r--Eigen/src/Core/arch/AVX512/PacketMath.h280
1 files changed, 211 insertions, 69 deletions
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h
index bf7f0db4f..9acec3439 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -58,23 +58,35 @@ struct packet_traits<half> : default_packet_traits {
Vectorizable = 1,
AlignedOnScalar = 1,
size = 16,
- HasHalfPacket = 0,
+ HasHalfPacket = 1,
+
+ HasCmp = 1,
HasAdd = 1,
HasSub = 1,
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
- HasAbs = 0,
+ HasAbs = 1,
HasAbs2 = 0,
- HasMin = 0,
- HasMax = 0,
- HasConj = 0,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
HasSetLinear = 0,
- HasSqrt = 0,
- HasRsqrt = 0,
- HasExp = 0,
- HasLog = 0,
- HasBlend = 0
+ HasLog = 1,
+ HasLog1p = 1,
+ HasExpm1 = 1,
+ HasExp = 1,
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
+ HasBlend = 0,
+ HasRound = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1
};
};
@@ -87,6 +99,11 @@ template<> struct packet_traits<float> : default_packet_traits
AlignedOnScalar = 1,
size = 16,
HasHalfPacket = 1,
+
+ HasAbs = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
HasBlend = 0,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
@@ -105,7 +122,11 @@ template<> struct packet_traits<float> : default_packet_traits
HasErf = EIGEN_FAST_MATH,
#endif
HasCmp = 1,
- HasDiv = 1
+ HasDiv = 1,
+ HasRound = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1
};
};
template<> struct packet_traits<double> : default_packet_traits
@@ -125,7 +146,11 @@ template<> struct packet_traits<double> : default_packet_traits
HasRsqrt = EIGEN_FAST_MATH,
#endif
HasCmp = 1,
- HasDiv = 1
+ HasDiv = 1,
+ HasRound = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1
};
};
@@ -165,7 +190,7 @@ struct unpacket_traits<Packet16i> {
template<>
struct unpacket_traits<Packet16h> {
typedef Eigen::half type;
- typedef Packet16h half;
+ typedef Packet8h half;
enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
};
@@ -188,10 +213,14 @@ EIGEN_STRONG_INLINE Packet16f pset1frombits<Packet16f>(unsigned int from) {
}
template <>
-EIGEN_STRONG_INLINE Packet8d pset1frombits<Packet8d>(uint64_t from) {
+EIGEN_STRONG_INLINE Packet8d pset1frombits<Packet8d>(const numext::uint64_t from) {
return _mm512_castsi512_pd(_mm512_set1_epi64(from));
}
+template<> EIGEN_STRONG_INLINE Packet16f pzero(const Packet16f& /*a*/) { return _mm512_setzero_ps(); }
+template<> EIGEN_STRONG_INLINE Packet8d pzero(const Packet8d& /*a*/) { return _mm512_setzero_pd(); }
+template<> EIGEN_STRONG_INLINE Packet16i pzero(const Packet16i& /*a*/) { return _mm512_setzero_si512(); }
+
template <>
EIGEN_STRONG_INLINE Packet16f pload1<Packet16f>(const float* from) {
return _mm512_broadcastss_ps(_mm_load_ps1(from));
@@ -281,7 +310,7 @@ EIGEN_STRONG_INLINE Packet8d pmul<Packet8d>(const Packet8d& a,
template <>
EIGEN_STRONG_INLINE Packet16i pmul<Packet16i>(const Packet16i& a,
const Packet16i& b) {
- return _mm512_mul_epi32(a, b);
+ return _mm512_mullo_epi32(a, b);
}
template <>
@@ -482,6 +511,15 @@ EIGEN_STRONG_INLINE Packet8d pcmp_lt_or_nan(const Packet8d& a, const Packet8d& b
_mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
}
+template<> EIGEN_STRONG_INLINE Packet16f print<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_CUR_DIRECTION); }
+template<> EIGEN_STRONG_INLINE Packet8d print<Packet8d>(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_CUR_DIRECTION); }
+
+template<> EIGEN_STRONG_INLINE Packet16f pceil<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_TO_POS_INF); }
+template<> EIGEN_STRONG_INLINE Packet8d pceil<Packet8d>(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_TO_POS_INF); }
+
+template<> EIGEN_STRONG_INLINE Packet16f pfloor<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_TO_NEG_INF); }
+template<> EIGEN_STRONG_INLINE Packet8d pfloor<Packet8d>(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_TO_NEG_INF); }
+
template <>
EIGEN_STRONG_INLINE Packet16i ptrue<Packet16i>(const Packet16i& /*a*/) {
return _mm512_set1_epi32(0xffffffffu);
@@ -598,6 +636,21 @@ EIGEN_STRONG_INLINE Packet8d pandnot<Packet8d>(const Packet8d& a,const Packet8d&
#endif
}
+template<> EIGEN_STRONG_INLINE Packet16f pround<Packet16f>(const Packet16f& a)
+{
+ // Work-around for default std::round rounding mode.
+ const Packet16f mask = pset1frombits<Packet16f>(static_cast<numext::uint32_t>(0x80000000u));
+ const Packet16f prev0dot5 = pset1frombits<Packet16f>(static_cast<numext::uint32_t>(0x3EFFFFFFu));
+ return _mm512_roundscale_ps(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
+}
+template<> EIGEN_STRONG_INLINE Packet8d pround<Packet8d>(const Packet8d& a)
+{
+ // Work-around for default std::round rounding mode.
+ const Packet8d mask = pset1frombits<Packet8d>(static_cast<numext::uint64_t>(0x8000000000000000ull));
+ const Packet8d prev0dot5 = pset1frombits<Packet8d>(static_cast<numext::uint64_t>(0x3FDFFFFFFFFFFFFFull));
+ return _mm512_roundscale_pd(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
+}
+
template<int N> EIGEN_STRONG_INLINE Packet16i parithmetic_shift_right(Packet16i a) {
return _mm512_srai_epi32(a, N);
}
@@ -840,7 +893,24 @@ EIGEN_STRONG_INLINE Packet8d pfrexp<Packet8d>(const Packet8d& a, Packet8d& expon
const Packet8d cst_half = pset1<Packet8d>(0.5);
const Packet8d cst_inv_mant_mask = pset1frombits<Packet8d>(static_cast<uint64_t>(~0x7ff0000000000000ull));
exponent = psub(_mm512_cvtepi64_pd(_mm512_srli_epi64(_mm512_castpd_si512(a), 52)), cst_1022d);
- return por(pand(a, cst_inv_mant_mask), cst_half);
+ return por(pand(a, cst_inv_mant_mask), cst_half);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16f pldexp<Packet16f>(const Packet16f& a, const Packet16f& exponent) {
+ return pldexp_float(a,exponent);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8d pldexp<Packet8d>(const Packet8d& a, const Packet8d& exponent) {
+ // Build e=2^n by constructing the exponents in a 256-bit vector and
+ // shifting them to where they belong in double-precision values.
+ Packet8i cst_1023 = pset1<Packet8i>(1023);
+ __m256i emm0 = _mm512_cvtpd_epi32(exponent);
+ emm0 = _mm256_add_epi32(emm0, cst_1023);
+ emm0 = _mm256_shuffle_epi32(emm0, _MM_SHUFFLE(3, 1, 2, 0));
+ __m256i lo = _mm256_slli_epi64(emm0, 52);
+ __m256i hi = _mm256_slli_epi64(_mm256_srli_epi64(emm0, 32), 52);
+ __m512d b = _mm512_castsi512_pd(_mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1));
+ return pmul(a, b);
}
#ifdef EIGEN_VECTORIZE_AVX512DQ
@@ -1270,22 +1340,6 @@ EIGEN_STRONG_INLINE Packet8d pblend(const Selector<8>& ifPacket,
return _mm512_mask_blend_pd(m, elsePacket, thenPacket);
}
-template<> EIGEN_STRONG_INLINE Packet16i pcast<Packet16f, Packet16i>(const Packet16f& a) {
- return _mm512_cvttps_epi32(a);
-}
-
-template<> EIGEN_STRONG_INLINE Packet16f pcast<Packet16i, Packet16f>(const Packet16i& a) {
- return _mm512_cvtepi32_ps(a);
-}
-
-template<> EIGEN_STRONG_INLINE Packet16i preinterpret<Packet16i,Packet16f>(const Packet16f& a) {
- return _mm512_castps_si512(a);
-}
-
-template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f,Packet16i>(const Packet16i& a) {
- return _mm512_castsi512_ps(a);
-}
-
// Packet math for Eigen::half
template<> EIGEN_STRONG_INLINE Packet16h pset1<Packet16h>(const Eigen::half& from) {
return _mm256_set1_epi16(from.x);
@@ -1398,6 +1452,29 @@ template<> EIGEN_STRONG_INLINE Packet16h ptrue(const Packet16h& a) {
return ptrue(Packet8i(a));
}
+template <>
+EIGEN_STRONG_INLINE Packet16h pabs(const Packet16h& a) {
+ const __m256i sign_mask = _mm256_set1_epi16(static_cast<numext::uint16_t>(0x8000));
+ return _mm256_andnot_si256(sign_mask, a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16h pmin<Packet16h>(const Packet16h& a,
+ const Packet16h& b) {
+ return float2half(pmin<Packet16f>(half2float(a), half2float(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16h pmax<Packet16h>(const Packet16h& a,
+ const Packet16h& b) {
+ return float2half(pmax<Packet16f>(half2float(a), half2float(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16h plset<Packet16h>(const half& a) {
+ return float2half(plset<Packet16f>(static_cast<float>(a)));
+}
+
template<> EIGEN_STRONG_INLINE Packet16h por(const Packet16h& a,const Packet16h& b) {
// in some cases Packet8i is a wrapper around __m256i, so we need to
// cast to Packet8i to call the correct overload.
@@ -1417,12 +1494,42 @@ template<> EIGEN_STRONG_INLINE Packet16h pselect(const Packet16h& mask, const Pa
return _mm256_blendv_epi8(b, a, mask);
}
+template<> EIGEN_STRONG_INLINE Packet16h pround<Packet16h>(const Packet16h& a) {
+ return float2half(pround<Packet16f>(half2float(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h print<Packet16h>(const Packet16h& a) {
+ return float2half(print<Packet16f>(half2float(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h pceil<Packet16h>(const Packet16h& a) {
+ return float2half(pceil<Packet16f>(half2float(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h pfloor<Packet16h>(const Packet16h& a) {
+ return float2half(pfloor<Packet16f>(half2float(a)));
+}
+
template<> EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a,const Packet16h& b) {
Packet16f af = half2float(a);
Packet16f bf = half2float(b);
return Pack32To16(pcmp_eq(af, bf));
}
+template<> EIGEN_STRONG_INLINE Packet16h pcmp_le(const Packet16h& a,const Packet16h& b) {
+ return Pack32To16(pcmp_le(half2float(a), half2float(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h pcmp_lt(const Packet16h& a,const Packet16h& b) {
+ return Pack32To16(pcmp_lt(half2float(a), half2float(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h pcmp_lt_or_nan(const Packet16h& a,const Packet16h& b) {
+ return Pack32To16(pcmp_lt_or_nan(half2float(a), half2float(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h pconj(const Packet16h& a) { return a; }
+
template<> EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) {
Packet16h sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
return _mm256_xor_si256(a, sign_mask);
@@ -1461,6 +1568,25 @@ template<> EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& from) {
return half(predux(from_float));
}
+template <>
+EIGEN_STRONG_INLINE Packet8h predux_half_dowto4<Packet16h>(const Packet16h& a) {
+ Packet8h lane0 = _mm256_extractf128_si256(a, 0);
+ Packet8h lane1 = _mm256_extractf128_si256(a, 1);
+ return padd<Packet8h>(lane0, lane1);
+}
+
+template<> EIGEN_STRONG_INLINE Eigen::half predux_max<Packet16h>(const Packet16h& a) {
+ Packet16f af = half2float(a);
+ float reduced = predux_max<Packet16f>(af);
+ return Eigen::half(reduced);
+}
+
+template<> EIGEN_STRONG_INLINE Eigen::half predux_min<Packet16h>(const Packet16h& a) {
+ Packet16f af = half2float(a);
+ float reduced = predux_min<Packet16f>(af);
+ return Eigen::half(reduced);
+}
+
template<> EIGEN_STRONG_INLINE half predux_mul<Packet16h>(const Packet16h& from) {
Packet16f from_float = half2float(from);
return half(predux_mul(from_float));
@@ -1487,22 +1613,22 @@ template<> EIGEN_STRONG_INLINE void pscatter<half, Packet16h>(half* to, const Pa
{
EIGEN_ALIGN64 half aux[16];
pstore(aux, from);
- to[stride*0].x = aux[0].x;
- to[stride*1].x = aux[1].x;
- to[stride*2].x = aux[2].x;
- to[stride*3].x = aux[3].x;
- to[stride*4].x = aux[4].x;
- to[stride*5].x = aux[5].x;
- to[stride*6].x = aux[6].x;
- to[stride*7].x = aux[7].x;
- to[stride*8].x = aux[8].x;
- to[stride*9].x = aux[9].x;
- to[stride*10].x = aux[10].x;
- to[stride*11].x = aux[11].x;
- to[stride*12].x = aux[12].x;
- to[stride*13].x = aux[13].x;
- to[stride*14].x = aux[14].x;
- to[stride*15].x = aux[15].x;
+ to[stride*0] = aux[0];
+ to[stride*1] = aux[1];
+ to[stride*2] = aux[2];
+ to[stride*3] = aux[3];
+ to[stride*4] = aux[4];
+ to[stride*5] = aux[5];
+ to[stride*6] = aux[6];
+ to[stride*7] = aux[7];
+ to[stride*8] = aux[8];
+ to[stride*9] = aux[9];
+ to[stride*10] = aux[10];
+ to[stride*11] = aux[11];
+ to[stride*12] = aux[12];
+ to[stride*13] = aux[13];
+ to[stride*14] = aux[14];
+ to[stride*15] = aux[15];
}
EIGEN_STRONG_INLINE void
@@ -1694,7 +1820,7 @@ struct packet_traits<bfloat16> : default_packet_traits {
HasCos = EIGEN_FAST_MATH,
#if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT)
#ifdef EIGEN_VECTORIZE_AVX512DQ
- HasLog = 1,
+ HasLog = 1, // Currently fails test with bad accuracy.
HasLog1p = 1,
HasExpm1 = 1,
HasNdtri = 1,
@@ -1859,6 +1985,23 @@ EIGEN_STRONG_INLINE Packet16bf pselect(const Packet16bf& mask,
return _mm256_blendv_epi8(b, a, mask);
}
+template<> EIGEN_STRONG_INLINE Packet16bf pround<Packet16bf>(const Packet16bf& a)
+{
+ return F32ToBf16(pround<Packet16f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16bf print<Packet16bf>(const Packet16bf& a) {
+ return F32ToBf16(print<Packet16f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16bf pceil<Packet16bf>(const Packet16bf& a) {
+ return F32ToBf16(pceil<Packet16f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16bf pfloor<Packet16bf>(const Packet16bf& a) {
+ return F32ToBf16(pfloor<Packet16f>(Bf16ToF32(a)));
+}
+
template <>
EIGEN_STRONG_INLINE Packet16bf pcmp_eq(const Packet16bf& a,
const Packet16bf& b) {
@@ -1885,9 +2028,7 @@ EIGEN_STRONG_INLINE Packet16bf pcmp_lt_or_nan(const Packet16bf& a,
template <>
EIGEN_STRONG_INLINE Packet16bf pnegate(const Packet16bf& a) {
- Packet16bf sign_mask;
- sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
- Packet16bf result;
+ Packet16bf sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
return _mm256_xor_si256(a, sign_mask);
}
@@ -1898,7 +2039,8 @@ EIGEN_STRONG_INLINE Packet16bf pconj(const Packet16bf& a) {
template <>
EIGEN_STRONG_INLINE Packet16bf pabs(const Packet16bf& a) {
- return F32ToBf16(pabs<Packet16f>(Bf16ToF32(a)));
+ const __m256i sign_mask = _mm256_set1_epi16(static_cast<numext::uint16_t>(0x8000));
+ return _mm256_andnot_si256(sign_mask, a);
}
template <>
@@ -1997,22 +2139,22 @@ EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet16bf>(bfloat16* to,
Index stride) {
EIGEN_ALIGN64 bfloat16 aux[16];
pstore(aux, from);
- to[stride*0].value = aux[0].value;
- to[stride*1].value = aux[1].value;
- to[stride*2].value = aux[2].value;
- to[stride*3].value = aux[3].value;
- to[stride*4].value = aux[4].value;
- to[stride*5].value = aux[5].value;
- to[stride*6].value = aux[6].value;
- to[stride*7].value = aux[7].value;
- to[stride*8].value = aux[8].value;
- to[stride*9].value = aux[9].value;
- to[stride*10].value = aux[10].value;
- to[stride*11].value = aux[11].value;
- to[stride*12].value = aux[12].value;
- to[stride*13].value = aux[13].value;
- to[stride*14].value = aux[14].value;
- to[stride*15].value = aux[15].value;
+ to[stride*0] = aux[0];
+ to[stride*1] = aux[1];
+ to[stride*2] = aux[2];
+ to[stride*3] = aux[3];
+ to[stride*4] = aux[4];
+ to[stride*5] = aux[5];
+ to[stride*6] = aux[6];
+ to[stride*7] = aux[7];
+ to[stride*8] = aux[8];
+ to[stride*9] = aux[9];
+ to[stride*10] = aux[10];
+ to[stride*11] = aux[11];
+ to[stride*12] = aux[12];
+ to[stride*13] = aux[13];
+ to[stride*14] = aux[14];
+ to[stride*15] = aux[15];
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,16>& kernel) {