aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Antonio Sanchez <cantonios@google.com>2020-11-24 16:28:07 -0800
committerGravatar Antonio Sánchez <cantonios@google.com>2020-11-30 16:28:57 +0000
commit89f90b585d24b3c07946b4ffd8064e66ad5af94a (patch)
treec29344e3c03752faaaf2f8eee847811091688262
parentc5985c46f5de0a7a381262c5a8a973806db92f40 (diff)
AVX512 missing ops.
This allows the `packetmath` tests to pass for AVX512 on skylake. Made `half` and `bfloat16` consistent in terms of ops they support. Note the `log` tests are currently disabled for `bfloat16` since they fail due to poor precision (they were previously disabled for `Packet8bf` via test function specialization -- I just removed that specialization and disabled it in the generic test).
-rw-r--r--Eigen/src/Core/arch/AVX/PacketMath.h21
-rw-r--r--Eigen/src/Core/arch/AVX512/MathFunctions.h12
-rw-r--r--Eigen/src/Core/arch/AVX512/PacketMath.h280
-rw-r--r--Eigen/src/Core/arch/AVX512/TypeCasting.h16
-rw-r--r--test/packetmath.cpp129
5 files changed, 280 insertions, 178 deletions
diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h
index e9eaaa9e0..a9fc33791 100644
--- a/Eigen/src/Core/arch/AVX/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX/PacketMath.h
@@ -105,7 +105,8 @@ template<> struct packet_traits<double> : default_packet_traits
HasBlend = 1,
HasRound = 1,
HasFloor = 1,
- HasCeil = 1
+ HasCeil = 1,
+ HasRint = 1
};
};
@@ -278,7 +279,15 @@ template<> EIGEN_STRONG_INLINE Packet8i pconj(const Packet8i& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet8f pmul<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_mul_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4d pmul<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_mul_pd(a,b); }
-
+template<> EIGEN_STRONG_INLINE Packet8i pmul<Packet8i>(const Packet8i& a, const Packet8i& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_mullo_epi32(a,b);
+#else
+ const __m128i lo = _mm_mullo_epi32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0));
+ const __m128i hi = _mm_mullo_epi32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1));
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet8f pdiv<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_div_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4d pdiv<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_div_pd(a,b); }
@@ -499,14 +508,14 @@ template<> EIGEN_STRONG_INLINE Packet8i pandnot<Packet8i>(const Packet8i& a, con
template<> EIGEN_STRONG_INLINE Packet8f pround<Packet8f>(const Packet8f& a)
{
- const Packet8f mask = pset1frombits<Packet8f>(0x80000000u);
- const Packet8f prev0dot5 = pset1frombits<Packet8f>(0x3EFFFFFFu);
+ const Packet8f mask = pset1frombits<Packet8f>(static_cast<numext::uint32_t>(0x80000000u));
+ const Packet8f prev0dot5 = pset1frombits<Packet8f>(static_cast<numext::uint32_t>(0x3EFFFFFFu));
return _mm256_round_ps(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
}
template<> EIGEN_STRONG_INLINE Packet4d pround<Packet4d>(const Packet4d& a)
{
- const Packet4d mask = _mm256_castsi256_pd(_mm256_set_epi64x(0x8000000000000000ull, 0x8000000000000000ull, 0x8000000000000000ull, 0x8000000000000000ull));
- const Packet4d prev0dot5 = _mm256_castsi256_pd(_mm256_set_epi64x(0x3FDFFFFFFFFFFFFFull, 0x3FDFFFFFFFFFFFFFull, 0x3FDFFFFFFFFFFFFFull, 0x3FDFFFFFFFFFFFFFull));
+ const Packet4d mask = pset1frombits<Packet4d>(static_cast<numext::uint64_t>(0x8000000000000000ull));
+ const Packet4d prev0dot5 = pset1frombits<Packet4d>(static_cast<numext::uint64_t>(0x3FDFFFFFFFFFFFFFull));
return _mm256_round_pd(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
}
diff --git a/Eigen/src/Core/arch/AVX512/MathFunctions.h b/Eigen/src/Core/arch/AVX512/MathFunctions.h
index bfd30c01a..2c34868a7 100644
--- a/Eigen/src/Core/arch/AVX512/MathFunctions.h
+++ b/Eigen/src/Core/arch/AVX512/MathFunctions.h
@@ -48,6 +48,7 @@ plog<Packet8d>(const Packet8d& _x) {
return plog_double(_x);
}
+F16_PACKET_FUNCTION(Packet16f, Packet16h, plog)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog)
#endif
@@ -174,6 +175,7 @@ pexp<Packet8d>(const Packet8d& _x) {
return pmax(pmul(x, e), _x);
}*/
+F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp)
// Functions for sqrt.
@@ -232,6 +234,7 @@ EIGEN_STRONG_INLINE Packet8d psqrt<Packet8d>(const Packet8d& x) {
}
#endif
+F16_PACKET_FUNCTION(Packet16f, Packet16h, psqrt)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psqrt)
// prsqrt for float.
@@ -256,7 +259,7 @@ prsqrt<Packet16f>(const Packet16f& _x) {
__mmask16 inf_mask = _mm512_cmp_ps_mask(_x, p16f_inf, _CMP_EQ_OQ);
__mmask16 not_pos_mask = _mm512_cmp_ps_mask(_x, _mm512_setzero_ps(), _CMP_LE_OQ);
__mmask16 not_finite_pos_mask = not_pos_mask | inf_mask;
-
+
// Compute an approximate result using the rsqrt intrinsic, forcing +inf
// for denormals for consistency with AVX and SSE implementations.
Packet16f y_approx = _mm512_rsqrt14_ps(_x);
@@ -281,6 +284,7 @@ EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
}
#endif
+F16_PACKET_FUNCTION(Packet16f, Packet16h, prsqrt)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, prsqrt)
// prsqrt for double.
@@ -336,6 +340,7 @@ Packet16f plog1p<Packet16f>(const Packet16f& _x) {
return generic_plog1p(_x);
}
+F16_PACKET_FUNCTION(Packet16f, Packet16h, plog1p)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog1p)
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
@@ -343,6 +348,7 @@ Packet16f pexpm1<Packet16f>(const Packet16f& _x) {
return generic_expm1(_x);
}
+F16_PACKET_FUNCTION(Packet16f, Packet16h, pexpm1)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexpm1)
#endif
@@ -367,6 +373,10 @@ ptanh<Packet16f>(const Packet16f& _x) {
return internal::generic_fast_tanh_float(_x);
}
+F16_PACKET_FUNCTION(Packet16f, Packet16h, psin)
+F16_PACKET_FUNCTION(Packet16f, Packet16h, pcos)
+F16_PACKET_FUNCTION(Packet16f, Packet16h, ptanh)
+
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psin)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pcos)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, ptanh)
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h
index bf7f0db4f..9acec3439 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -58,23 +58,35 @@ struct packet_traits<half> : default_packet_traits {
Vectorizable = 1,
AlignedOnScalar = 1,
size = 16,
- HasHalfPacket = 0,
+ HasHalfPacket = 1,
+
+ HasCmp = 1,
HasAdd = 1,
HasSub = 1,
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
- HasAbs = 0,
+ HasAbs = 1,
HasAbs2 = 0,
- HasMin = 0,
- HasMax = 0,
- HasConj = 0,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
HasSetLinear = 0,
- HasSqrt = 0,
- HasRsqrt = 0,
- HasExp = 0,
- HasLog = 0,
- HasBlend = 0
+ HasLog = 1,
+ HasLog1p = 1,
+ HasExpm1 = 1,
+ HasExp = 1,
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
+ HasBlend = 0,
+ HasRound = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1
};
};
@@ -87,6 +99,11 @@ template<> struct packet_traits<float> : default_packet_traits
AlignedOnScalar = 1,
size = 16,
HasHalfPacket = 1,
+
+ HasAbs = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
HasBlend = 0,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
@@ -105,7 +122,11 @@ template<> struct packet_traits<float> : default_packet_traits
HasErf = EIGEN_FAST_MATH,
#endif
HasCmp = 1,
- HasDiv = 1
+ HasDiv = 1,
+ HasRound = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1
};
};
template<> struct packet_traits<double> : default_packet_traits
@@ -125,7 +146,11 @@ template<> struct packet_traits<double> : default_packet_traits
HasRsqrt = EIGEN_FAST_MATH,
#endif
HasCmp = 1,
- HasDiv = 1
+ HasDiv = 1,
+ HasRound = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1
};
};
@@ -165,7 +190,7 @@ struct unpacket_traits<Packet16i> {
template<>
struct unpacket_traits<Packet16h> {
typedef Eigen::half type;
- typedef Packet16h half;
+ typedef Packet8h half;
enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
};
@@ -188,10 +213,14 @@ EIGEN_STRONG_INLINE Packet16f pset1frombits<Packet16f>(unsigned int from) {
}
template <>
-EIGEN_STRONG_INLINE Packet8d pset1frombits<Packet8d>(uint64_t from) {
+EIGEN_STRONG_INLINE Packet8d pset1frombits<Packet8d>(const numext::uint64_t from) {
return _mm512_castsi512_pd(_mm512_set1_epi64(from));
}
+template<> EIGEN_STRONG_INLINE Packet16f pzero(const Packet16f& /*a*/) { return _mm512_setzero_ps(); }
+template<> EIGEN_STRONG_INLINE Packet8d pzero(const Packet8d& /*a*/) { return _mm512_setzero_pd(); }
+template<> EIGEN_STRONG_INLINE Packet16i pzero(const Packet16i& /*a*/) { return _mm512_setzero_si512(); }
+
template <>
EIGEN_STRONG_INLINE Packet16f pload1<Packet16f>(const float* from) {
return _mm512_broadcastss_ps(_mm_load_ps1(from));
@@ -281,7 +310,7 @@ EIGEN_STRONG_INLINE Packet8d pmul<Packet8d>(const Packet8d& a,
template <>
EIGEN_STRONG_INLINE Packet16i pmul<Packet16i>(const Packet16i& a,
const Packet16i& b) {
- return _mm512_mul_epi32(a, b);
+ return _mm512_mullo_epi32(a, b);
}
template <>
@@ -482,6 +511,15 @@ EIGEN_STRONG_INLINE Packet8d pcmp_lt_or_nan(const Packet8d& a, const Packet8d& b
_mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
}
+template<> EIGEN_STRONG_INLINE Packet16f print<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_CUR_DIRECTION); }
+template<> EIGEN_STRONG_INLINE Packet8d print<Packet8d>(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_CUR_DIRECTION); }
+
+template<> EIGEN_STRONG_INLINE Packet16f pceil<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_TO_POS_INF); }
+template<> EIGEN_STRONG_INLINE Packet8d pceil<Packet8d>(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_TO_POS_INF); }
+
+template<> EIGEN_STRONG_INLINE Packet16f pfloor<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_TO_NEG_INF); }
+template<> EIGEN_STRONG_INLINE Packet8d pfloor<Packet8d>(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_TO_NEG_INF); }
+
template <>
EIGEN_STRONG_INLINE Packet16i ptrue<Packet16i>(const Packet16i& /*a*/) {
return _mm512_set1_epi32(0xffffffffu);
@@ -598,6 +636,21 @@ EIGEN_STRONG_INLINE Packet8d pandnot<Packet8d>(const Packet8d& a,const Packet8d&
#endif
}
+template<> EIGEN_STRONG_INLINE Packet16f pround<Packet16f>(const Packet16f& a)
+{
+ // Work-around for default std::round rounding mode.
+ const Packet16f mask = pset1frombits<Packet16f>(static_cast<numext::uint32_t>(0x80000000u));
+ const Packet16f prev0dot5 = pset1frombits<Packet16f>(static_cast<numext::uint32_t>(0x3EFFFFFFu));
+ return _mm512_roundscale_ps(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
+}
+template<> EIGEN_STRONG_INLINE Packet8d pround<Packet8d>(const Packet8d& a)
+{
+ // Work-around for default std::round rounding mode.
+ const Packet8d mask = pset1frombits<Packet8d>(static_cast<numext::uint64_t>(0x8000000000000000ull));
+ const Packet8d prev0dot5 = pset1frombits<Packet8d>(static_cast<numext::uint64_t>(0x3FDFFFFFFFFFFFFFull));
+ return _mm512_roundscale_pd(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
+}
+
template<int N> EIGEN_STRONG_INLINE Packet16i parithmetic_shift_right(Packet16i a) {
return _mm512_srai_epi32(a, N);
}
@@ -840,7 +893,24 @@ EIGEN_STRONG_INLINE Packet8d pfrexp<Packet8d>(const Packet8d& a, Packet8d& expon
const Packet8d cst_half = pset1<Packet8d>(0.5);
const Packet8d cst_inv_mant_mask = pset1frombits<Packet8d>(static_cast<uint64_t>(~0x7ff0000000000000ull));
exponent = psub(_mm512_cvtepi64_pd(_mm512_srli_epi64(_mm512_castpd_si512(a), 52)), cst_1022d);
- return por(pand(a, cst_inv_mant_mask), cst_half);
+ return por(pand(a, cst_inv_mant_mask), cst_half);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16f pldexp<Packet16f>(const Packet16f& a, const Packet16f& exponent) {
+ return pldexp_float(a,exponent);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8d pldexp<Packet8d>(const Packet8d& a, const Packet8d& exponent) {
+ // Build e=2^n by constructing the exponents in a 256-bit vector and
+ // shifting them to where they belong in double-precision values.
+ Packet8i cst_1023 = pset1<Packet8i>(1023);
+ __m256i emm0 = _mm512_cvtpd_epi32(exponent);
+ emm0 = _mm256_add_epi32(emm0, cst_1023);
+ emm0 = _mm256_shuffle_epi32(emm0, _MM_SHUFFLE(3, 1, 2, 0));
+ __m256i lo = _mm256_slli_epi64(emm0, 52);
+ __m256i hi = _mm256_slli_epi64(_mm256_srli_epi64(emm0, 32), 52);
+ __m512d b = _mm512_castsi512_pd(_mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1));
+ return pmul(a, b);
}
#ifdef EIGEN_VECTORIZE_AVX512DQ
@@ -1270,22 +1340,6 @@ EIGEN_STRONG_INLINE Packet8d pblend(const Selector<8>& ifPacket,
return _mm512_mask_blend_pd(m, elsePacket, thenPacket);
}
-template<> EIGEN_STRONG_INLINE Packet16i pcast<Packet16f, Packet16i>(const Packet16f& a) {
- return _mm512_cvttps_epi32(a);
-}
-
-template<> EIGEN_STRONG_INLINE Packet16f pcast<Packet16i, Packet16f>(const Packet16i& a) {
- return _mm512_cvtepi32_ps(a);
-}
-
-template<> EIGEN_STRONG_INLINE Packet16i preinterpret<Packet16i,Packet16f>(const Packet16f& a) {
- return _mm512_castps_si512(a);
-}
-
-template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f,Packet16i>(const Packet16i& a) {
- return _mm512_castsi512_ps(a);
-}
-
// Packet math for Eigen::half
template<> EIGEN_STRONG_INLINE Packet16h pset1<Packet16h>(const Eigen::half& from) {
return _mm256_set1_epi16(from.x);
@@ -1398,6 +1452,29 @@ template<> EIGEN_STRONG_INLINE Packet16h ptrue(const Packet16h& a) {
return ptrue(Packet8i(a));
}
+template <>
+EIGEN_STRONG_INLINE Packet16h pabs(const Packet16h& a) {
+ const __m256i sign_mask = _mm256_set1_epi16(static_cast<numext::uint16_t>(0x8000));
+ return _mm256_andnot_si256(sign_mask, a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16h pmin<Packet16h>(const Packet16h& a,
+ const Packet16h& b) {
+ return float2half(pmin<Packet16f>(half2float(a), half2float(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16h pmax<Packet16h>(const Packet16h& a,
+ const Packet16h& b) {
+ return float2half(pmax<Packet16f>(half2float(a), half2float(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16h plset<Packet16h>(const half& a) {
+ return float2half(plset<Packet16f>(static_cast<float>(a)));
+}
+
template<> EIGEN_STRONG_INLINE Packet16h por(const Packet16h& a,const Packet16h& b) {
// in some cases Packet8i is a wrapper around __m256i, so we need to
// cast to Packet8i to call the correct overload.
@@ -1417,12 +1494,42 @@ template<> EIGEN_STRONG_INLINE Packet16h pselect(const Packet16h& mask, const Pa
return _mm256_blendv_epi8(b, a, mask);
}
+template<> EIGEN_STRONG_INLINE Packet16h pround<Packet16h>(const Packet16h& a) {
+ return float2half(pround<Packet16f>(half2float(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h print<Packet16h>(const Packet16h& a) {
+ return float2half(print<Packet16f>(half2float(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h pceil<Packet16h>(const Packet16h& a) {
+ return float2half(pceil<Packet16f>(half2float(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h pfloor<Packet16h>(const Packet16h& a) {
+ return float2half(pfloor<Packet16f>(half2float(a)));
+}
+
template<> EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a,const Packet16h& b) {
Packet16f af = half2float(a);
Packet16f bf = half2float(b);
return Pack32To16(pcmp_eq(af, bf));
}
+template<> EIGEN_STRONG_INLINE Packet16h pcmp_le(const Packet16h& a,const Packet16h& b) {
+ return Pack32To16(pcmp_le(half2float(a), half2float(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h pcmp_lt(const Packet16h& a,const Packet16h& b) {
+ return Pack32To16(pcmp_lt(half2float(a), half2float(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h pcmp_lt_or_nan(const Packet16h& a,const Packet16h& b) {
+ return Pack32To16(pcmp_lt_or_nan(half2float(a), half2float(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h pconj(const Packet16h& a) { return a; }
+
template<> EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) {
Packet16h sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
return _mm256_xor_si256(a, sign_mask);
@@ -1461,6 +1568,25 @@ template<> EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& from) {
return half(predux(from_float));
}
+template <>
+EIGEN_STRONG_INLINE Packet8h predux_half_dowto4<Packet16h>(const Packet16h& a) {
+ Packet8h lane0 = _mm256_extractf128_si256(a, 0);
+ Packet8h lane1 = _mm256_extractf128_si256(a, 1);
+ return padd<Packet8h>(lane0, lane1);
+}
+
+template<> EIGEN_STRONG_INLINE Eigen::half predux_max<Packet16h>(const Packet16h& a) {
+ Packet16f af = half2float(a);
+ float reduced = predux_max<Packet16f>(af);
+ return Eigen::half(reduced);
+}
+
+template<> EIGEN_STRONG_INLINE Eigen::half predux_min<Packet16h>(const Packet16h& a) {
+ Packet16f af = half2float(a);
+ float reduced = predux_min<Packet16f>(af);
+ return Eigen::half(reduced);
+}
+
template<> EIGEN_STRONG_INLINE half predux_mul<Packet16h>(const Packet16h& from) {
Packet16f from_float = half2float(from);
return half(predux_mul(from_float));
@@ -1487,22 +1613,22 @@ template<> EIGEN_STRONG_INLINE void pscatter<half, Packet16h>(half* to, const Pa
{
EIGEN_ALIGN64 half aux[16];
pstore(aux, from);
- to[stride*0].x = aux[0].x;
- to[stride*1].x = aux[1].x;
- to[stride*2].x = aux[2].x;
- to[stride*3].x = aux[3].x;
- to[stride*4].x = aux[4].x;
- to[stride*5].x = aux[5].x;
- to[stride*6].x = aux[6].x;
- to[stride*7].x = aux[7].x;
- to[stride*8].x = aux[8].x;
- to[stride*9].x = aux[9].x;
- to[stride*10].x = aux[10].x;
- to[stride*11].x = aux[11].x;
- to[stride*12].x = aux[12].x;
- to[stride*13].x = aux[13].x;
- to[stride*14].x = aux[14].x;
- to[stride*15].x = aux[15].x;
+ to[stride*0] = aux[0];
+ to[stride*1] = aux[1];
+ to[stride*2] = aux[2];
+ to[stride*3] = aux[3];
+ to[stride*4] = aux[4];
+ to[stride*5] = aux[5];
+ to[stride*6] = aux[6];
+ to[stride*7] = aux[7];
+ to[stride*8] = aux[8];
+ to[stride*9] = aux[9];
+ to[stride*10] = aux[10];
+ to[stride*11] = aux[11];
+ to[stride*12] = aux[12];
+ to[stride*13] = aux[13];
+ to[stride*14] = aux[14];
+ to[stride*15] = aux[15];
}
EIGEN_STRONG_INLINE void
@@ -1694,7 +1820,7 @@ struct packet_traits<bfloat16> : default_packet_traits {
HasCos = EIGEN_FAST_MATH,
#if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT)
#ifdef EIGEN_VECTORIZE_AVX512DQ
- HasLog = 1,
+ HasLog = 1, // Currently fails test with bad accuracy.
HasLog1p = 1,
HasExpm1 = 1,
HasNdtri = 1,
@@ -1859,6 +1985,23 @@ EIGEN_STRONG_INLINE Packet16bf pselect(const Packet16bf& mask,
return _mm256_blendv_epi8(b, a, mask);
}
+template<> EIGEN_STRONG_INLINE Packet16bf pround<Packet16bf>(const Packet16bf& a)
+{
+ return F32ToBf16(pround<Packet16f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16bf print<Packet16bf>(const Packet16bf& a) {
+ return F32ToBf16(print<Packet16f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16bf pceil<Packet16bf>(const Packet16bf& a) {
+ return F32ToBf16(pceil<Packet16f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16bf pfloor<Packet16bf>(const Packet16bf& a) {
+ return F32ToBf16(pfloor<Packet16f>(Bf16ToF32(a)));
+}
+
template <>
EIGEN_STRONG_INLINE Packet16bf pcmp_eq(const Packet16bf& a,
const Packet16bf& b) {
@@ -1885,9 +2028,7 @@ EIGEN_STRONG_INLINE Packet16bf pcmp_lt_or_nan(const Packet16bf& a,
template <>
EIGEN_STRONG_INLINE Packet16bf pnegate(const Packet16bf& a) {
- Packet16bf sign_mask;
- sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
- Packet16bf result;
+ Packet16bf sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
return _mm256_xor_si256(a, sign_mask);
}
@@ -1898,7 +2039,8 @@ EIGEN_STRONG_INLINE Packet16bf pconj(const Packet16bf& a) {
template <>
EIGEN_STRONG_INLINE Packet16bf pabs(const Packet16bf& a) {
- return F32ToBf16(pabs<Packet16f>(Bf16ToF32(a)));
+ const __m256i sign_mask = _mm256_set1_epi16(static_cast<numext::uint16_t>(0x8000));
+ return _mm256_andnot_si256(sign_mask, a);
}
template <>
@@ -1997,22 +2139,22 @@ EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet16bf>(bfloat16* to,
Index stride) {
EIGEN_ALIGN64 bfloat16 aux[16];
pstore(aux, from);
- to[stride*0].value = aux[0].value;
- to[stride*1].value = aux[1].value;
- to[stride*2].value = aux[2].value;
- to[stride*3].value = aux[3].value;
- to[stride*4].value = aux[4].value;
- to[stride*5].value = aux[5].value;
- to[stride*6].value = aux[6].value;
- to[stride*7].value = aux[7].value;
- to[stride*8].value = aux[8].value;
- to[stride*9].value = aux[9].value;
- to[stride*10].value = aux[10].value;
- to[stride*11].value = aux[11].value;
- to[stride*12].value = aux[12].value;
- to[stride*13].value = aux[13].value;
- to[stride*14].value = aux[14].value;
- to[stride*15].value = aux[15].value;
+ to[stride*0] = aux[0];
+ to[stride*1] = aux[1];
+ to[stride*2] = aux[2];
+ to[stride*3] = aux[3];
+ to[stride*4] = aux[4];
+ to[stride*5] = aux[5];
+ to[stride*6] = aux[6];
+ to[stride*7] = aux[7];
+ to[stride*8] = aux[8];
+ to[stride*9] = aux[9];
+ to[stride*10] = aux[10];
+ to[stride*11] = aux[11];
+ to[stride*12] = aux[12];
+ to[stride*13] = aux[13];
+ to[stride*14] = aux[14];
+ to[stride*15] = aux[15];
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,16>& kernel) {
diff --git a/Eigen/src/Core/arch/AVX512/TypeCasting.h b/Eigen/src/Core/arch/AVX512/TypeCasting.h
index e643b18a7..330412729 100644
--- a/Eigen/src/Core/arch/AVX512/TypeCasting.h
+++ b/Eigen/src/Core/arch/AVX512/TypeCasting.h
@@ -14,6 +14,22 @@ namespace Eigen {
namespace internal {
+template<> EIGEN_STRONG_INLINE Packet16i pcast<Packet16f, Packet16i>(const Packet16f& a) {
+ return _mm512_cvttps_epi32(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16f pcast<Packet16i, Packet16f>(const Packet16i& a) {
+ return _mm512_cvtepi32_ps(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16i preinterpret<Packet16i, Packet16f>(const Packet16f& a) {
+ return _mm512_castps_si512(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet16i>(const Packet16i& a) {
+ return _mm512_castsi512_ps(a);
+}
+
template <>
struct type_casting_traits<half, float> {
enum {
diff --git a/test/packetmath.cpp b/test/packetmath.cpp
index d52f997dc..ae0ead820 100644
--- a/test/packetmath.cpp
+++ b/test/packetmath.cpp
@@ -618,7 +618,10 @@ void packetmath_real() {
test::packet_helper<PacketTraits::HasLog, Packet> h;
h.store(data2, internal::plog(h.load(data1)));
VERIFY((numext::isnan)(data2[0]));
- VERIFY_IS_APPROX(std::log(std::numeric_limits<Scalar>::epsilon()), data2[1]);
+ // TODO(cantonios): Re-enable for bfloat16.
+ if (!internal::is_same<Scalar, bfloat16>::value) {
+ VERIFY_IS_APPROX(std::log(data1[1]), data2[1]);
+ }
data1[0] = -std::numeric_limits<Scalar>::epsilon();
data1[1] = Scalar(0);
@@ -629,7 +632,10 @@ void packetmath_real() {
data1[0] = (std::numeric_limits<Scalar>::min)();
data1[1] = -(std::numeric_limits<Scalar>::min)();
h.store(data2, internal::plog(h.load(data1)));
- VERIFY_IS_APPROX(std::log((std::numeric_limits<Scalar>::min)()), data2[0]);
+ // TODO(cantonios): Re-enable for bfloat16.
+ if (!internal::is_same<Scalar, bfloat16>::value) {
+ VERIFY_IS_APPROX(std::log((std::numeric_limits<Scalar>::min)()), data2[0]);
+ }
VERIFY((numext::isnan)(data2[1]));
// Note: 32-bit arm always flushes denorms to zero.
@@ -731,54 +737,6 @@ void packetmath_real() {
VERIFY(test::areApprox(ref, data2, PacketSize) && #POP); \
}
-template <>
-void packetmath_real<bfloat16, typename internal::packet_traits<bfloat16>::type>(){
- typedef internal::packet_traits<bfloat16> PacketTraits;
- typedef internal::packet_traits<bfloat16>::type Packet;
-
- const int PacketSize = internal::unpacket_traits<Packet>::size;
- const int size = PacketSize * 4;
- EIGEN_ALIGN_MAX bfloat16 data1[PacketSize * 4];
- EIGEN_ALIGN_MAX bfloat16 data2[PacketSize * 4];
- EIGEN_ALIGN_MAX bfloat16 ref[PacketSize * 4];
-
- for (int i = 0; i < size; ++i) {
- data1[i] = bfloat16(internal::random<float>(0, 1) * std::pow(float(10), internal::random<float>(-6, 6)));
- data2[i] = bfloat16(internal::random<float>(0, 1) * std::pow(float(10), internal::random<float>(-6, 6)));
- data1[i] = bfloat16(0);
- }
-
- if (internal::random<float>(0, 1) < 0.1f) data1[internal::random<int>(0, PacketSize)] = bfloat16(0);
-
- CAST_CHECK_CWISE1_IF(PacketTraits::HasLog, std::log, internal::plog, bfloat16, float);
- CAST_CHECK_CWISE1_IF(PacketTraits::HasRsqrt, float(1) / std::sqrt, internal::prsqrt, bfloat16, float);
-
- for (int i = 0; i < size; ++i) {
- data1[i] = bfloat16(internal::random<float>(-1, 1) * std::pow(float(10), internal::random<float>(-3, 3)));
- data2[i] = bfloat16(internal::random<float>(-1, 1) * std::pow(float(10), internal::random<float>(-3, 3)));
- }
- CAST_CHECK_CWISE1_IF(PacketTraits::HasSin, std::sin, internal::psin, bfloat16, float);
- CAST_CHECK_CWISE1_IF(PacketTraits::HasCos, std::cos, internal::pcos, bfloat16, float);
- CAST_CHECK_CWISE1_IF(PacketTraits::HasTan, std::tan, internal::ptan, bfloat16, float);
-
- CAST_CHECK_CWISE1_IF(PacketTraits::HasRound, numext::round, internal::pround, bfloat16, float);
- CAST_CHECK_CWISE1_IF(PacketTraits::HasCeil, numext::ceil, internal::pceil, bfloat16, float);
- CAST_CHECK_CWISE1_IF(PacketTraits::HasFloor, numext::floor, internal::pfloor, bfloat16, float);
-
- for (int i = 0; i < size; ++i) {
- data1[i] = bfloat16(-1.5 + i);
- data2[i] = bfloat16(-1.5 + i);
- }
- CAST_CHECK_CWISE1_IF(PacketTraits::HasRound, numext::round, internal::pround, bfloat16, float);
-
- for (int i = 0; i < size; ++i) {
- data1[i] = bfloat16(internal::random<float>(-87, 88));
- data2[i] = bfloat16(internal::random<float>(-87, 88));
- }
- CAST_CHECK_CWISE1_IF(PacketTraits::HasExp, std::exp, internal::pexp, bfloat16, float);
-
-}
-
template <typename Scalar>
Scalar propagate_nan_max(const Scalar& a, const Scalar& b) {
if ((numext::isnan)(a)) return a;
@@ -793,6 +751,20 @@ Scalar propagate_nan_min(const Scalar& a, const Scalar& b) {
return (numext::mini)(a,b);
}
+template <typename Scalar>
+Scalar propagate_number_max(const Scalar& a, const Scalar& b) {
+ if ((numext::isnan)(a)) return b;
+ if ((numext::isnan)(b)) return a;
+ return (numext::maxi)(a,b);
+}
+
+template <typename Scalar>
+Scalar propagate_number_min(const Scalar& a, const Scalar& b) {
+ if ((numext::isnan)(a)) return b;
+ if ((numext::isnan)(b)) return a;
+ return (numext::mini)(a,b);
+}
+
template <typename Scalar, typename Packet>
void packetmath_notcomplex() {
typedef internal::packet_traits<Scalar> PacketTraits;
@@ -809,15 +781,9 @@ void packetmath_notcomplex() {
CHECK_CWISE2_IF(PacketTraits::HasMin, (std::min), internal::pmin);
CHECK_CWISE2_IF(PacketTraits::HasMax, (std::max), internal::pmax);
-#if EIGEN_HAS_CXX11_MATH
- using std::fmin;
- using std::fmax;
-#else
- using ::fmin;
- using ::fmax;
-#endif
- CHECK_CWISE2_IF(PacketTraits::HasMin, fmin, (internal::pmin<PropagateNumbers>));
- CHECK_CWISE2_IF(PacketTraits::HasMax, fmax, internal::pmax<PropagateNumbers>);
+
+ CHECK_CWISE2_IF(PacketTraits::HasMin, propagate_number_min, internal::pmin<PropagateNumbers>);
+ CHECK_CWISE2_IF(PacketTraits::HasMax, propagate_number_max, internal::pmax<PropagateNumbers>);
CHECK_CWISE1(numext::abs, internal::pabs);
CHECK_CWISE2_IF(PacketTraits::HasAbsDiff, REF_ABS_DIFF, internal::pabsdiff);
@@ -890,54 +856,13 @@ void packetmath_notcomplex() {
data1[i + PacketSize] = internal::random<bool>() ? std::numeric_limits<Scalar>::quiet_NaN() : Scalar(0);
}
// Note: NaN propagation is implementation defined for pmin/pmax, so we do not test it here.
- CHECK_CWISE2_IF(PacketTraits::HasMin, fmin, (internal::pmin<PropagateNumbers>));
- CHECK_CWISE2_IF(PacketTraits::HasMax, fmax, internal::pmax<PropagateNumbers>);
+ CHECK_CWISE2_IF(PacketTraits::HasMin, propagate_number_min, (internal::pmin<PropagateNumbers>));
+ CHECK_CWISE2_IF(PacketTraits::HasMax, propagate_number_max, internal::pmax<PropagateNumbers>);
CHECK_CWISE2_IF(PacketTraits::HasMin, propagate_nan_min, (internal::pmin<PropagateNaN>));
CHECK_CWISE2_IF(PacketTraits::HasMax, propagate_nan_max, internal::pmax<PropagateNaN>);
}
}
-template <>
-void packetmath_notcomplex<bfloat16, typename internal::packet_traits<bfloat16>::type>(){
- typedef bfloat16 Scalar;
- typedef internal::packet_traits<bfloat16>::type Packet;
- typedef internal::packet_traits<Scalar> PacketTraits;
- const int PacketSize = internal::unpacket_traits<Packet>::size;
-
- EIGEN_ALIGN_MAX Scalar data1[PacketSize * 4];
- EIGEN_ALIGN_MAX Scalar data2[PacketSize * 4];
- EIGEN_ALIGN_MAX Scalar ref[PacketSize * 4];
- Array<Scalar, Dynamic, 1>::Map(data1, PacketSize * 4).setRandom();
-
- ref[0] = data1[0];
- for (int i = 0; i < PacketSize; ++i) ref[0] = (std::min)(ref[0], data1[i]);
- VERIFY(internal::isApprox(ref[0], internal::predux_min(internal::pload<Packet>(data1))) && "internal::predux_min");
-
- VERIFY((!PacketTraits::Vectorizable) || PacketTraits::HasMin);
- VERIFY((!PacketTraits::Vectorizable) || PacketTraits::HasMax);
-
- CHECK_CWISE2_IF(PacketTraits::HasMin, (std::min), internal::pmin);
- CHECK_CWISE2_IF(PacketTraits::HasMax, (std::max), internal::pmax);
- CHECK_CWISE1(numext::abs, internal::pabs);
- CHECK_CWISE2_IF(PacketTraits::HasAbsDiff, REF_ABS_DIFF, internal::pabsdiff);
-
- ref[0] = data1[0];
- for (int i = 0; i < PacketSize; ++i) ref[0] = (std::max)(ref[0], data1[i]);
- VERIFY(internal::isApprox(ref[0], internal::predux_max(internal::pload<Packet>(data1))) && "internal::predux_max");
-
- {
- unsigned char* data1_bits = reinterpret_cast<unsigned char*>(data1);
- // predux_any
- for (unsigned int i = 0; i < PacketSize * sizeof(Scalar); ++i) data1_bits[i] = 0x0;
- VERIFY((!internal::predux_any(internal::pload<Packet>(data1))) && "internal::predux_any(0000)");
- for (int k = 0; k < PacketSize; ++k) {
- for (unsigned int i = 0; i < sizeof(Scalar); ++i) data1_bits[k * sizeof(Scalar) + i] = 0xff;
- VERIFY(internal::predux_any(internal::pload<Packet>(data1)) && "internal::predux_any(0101)");
- for (unsigned int i = 0; i < sizeof(Scalar); ++i) data1_bits[k * sizeof(Scalar) + i] = 0x00;
- }
- }
-}
-
template <typename Scalar, typename Packet, bool ConjLhs, bool ConjRhs>
void test_conj_helper(Scalar* data1, Scalar* data2, Scalar* ref, Scalar* pval) {
const int PacketSize = internal::unpacket_traits<Packet>::size;