aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/AVX512
diff options
context:
space:
mode:
authorGravatar Guoqiang QI <425418567@qq.com>2020-10-15 00:54:45 +0000
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-10-15 00:54:45 +0000
commit4700713faf92a7b72a926ee1ac75f75d59e58887 (patch)
tree1962e245d0b890f5d18fa836e2bd3e313465d975 /Eigen/src/Core/arch/AVX512
parentaf6f43d7ff7a7c9cfa2a1355e2b7e60f94e192fe (diff)
Add AVX plog<Packet4d> and AVX512 plog<Packet8d> ops,also unified AVX512 plog<Packet16f> op with generic api
Diffstat (limited to 'Eigen/src/Core/arch/AVX512')
-rw-r--r--Eigen/src/Core/arch/AVX512/MathFunctions.h101
-rw-r--r--Eigen/src/Core/arch/AVX512/PacketMath.h23
2 files changed, 29 insertions, 95 deletions
diff --git a/Eigen/src/Core/arch/AVX512/MathFunctions.h b/Eigen/src/Core/arch/AVX512/MathFunctions.h
index 83af5f5de..f6a43738d 100644
--- a/Eigen/src/Core/arch/AVX512/MathFunctions.h
+++ b/Eigen/src/Core/arch/AVX512/MathFunctions.h
@@ -35,104 +35,17 @@ namespace internal {
#define _EIGEN_DECLARE_CONST_Packet16bf_FROM_INT(NAME, X) \
const Packet16bf p16bf_##NAME = preinterpret<Packet16bf,Packet16i>(pset1<Packet16i>(X))
-// Natural logarithm
-// Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2)
-// and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can
-// be easily approximated by a polynomial centered on m=1 for stability.
#if defined(EIGEN_VECTORIZE_AVX512DQ)
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
plog<Packet16f>(const Packet16f& _x) {
- Packet16f x = _x;
- _EIGEN_DECLARE_CONST_Packet16f(1, 1.0f);
- _EIGEN_DECLARE_CONST_Packet16f(half, 0.5f);
- _EIGEN_DECLARE_CONST_Packet16f(126f, 126.0f);
-
- _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(inv_mant_mask, ~0x7f800000);
-
- // The smallest non denormalized float number.
- _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(min_norm_pos, 0x00800000);
- _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(minus_inf, 0xff800000);
- _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(pos_inf, 0x7f800000);
- _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(nan, 0x7fc00000);
-
- // Polynomial coefficients.
- _EIGEN_DECLARE_CONST_Packet16f(cephes_SQRTHF, 0.707106781186547524f);
- _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p0, 7.0376836292E-2f);
- _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p1, -1.1514610310E-1f);
- _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p2, 1.1676998740E-1f);
- _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p3, -1.2420140846E-1f);
- _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p4, +1.4249322787E-1f);
- _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p5, -1.6668057665E-1f);
- _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p6, +2.0000714765E-1f);
- _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p7, -2.4999993993E-1f);
- _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p8, +3.3333331174E-1f);
- _EIGEN_DECLARE_CONST_Packet16f(cephes_log_q1, -2.12194440e-4f);
- _EIGEN_DECLARE_CONST_Packet16f(cephes_log_q2, 0.693359375f);
-
- // invalid_mask is set to true when x is NaN
- __mmask16 invalid_mask = _mm512_cmp_ps_mask(x, _mm512_setzero_ps(), _CMP_NGE_UQ);
- __mmask16 iszero_mask = _mm512_cmp_ps_mask(x, _mm512_setzero_ps(), _CMP_EQ_OQ);
-
- // Truncate input values to the minimum positive normal.
- x = pmax(x, p16f_min_norm_pos);
-
- // Extract the shifted exponents.
- Packet16f emm0 = _mm512_cvtepi32_ps(_mm512_srli_epi32((preinterpret<Packet16i,Packet16f>(x)), 23));
- Packet16f e = _mm512_sub_ps(emm0, p16f_126f);
-
- // Set the exponents to -1, i.e. x are in the range [0.5,1).
- x = _mm512_and_ps(x, p16f_inv_mant_mask);
- x = _mm512_or_ps(x, p16f_half);
-
- // part2: Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2))
- // and shift by -1. The values are then centered around 0, which improves
- // the stability of the polynomial evaluation.
- // if( x < SQRTHF ) {
- // e -= 1;
- // x = x + x - 1.0;
- // } else { x = x - 1.0; }
- __mmask16 mask = _mm512_cmp_ps_mask(x, p16f_cephes_SQRTHF, _CMP_LT_OQ);
- Packet16f tmp = _mm512_mask_blend_ps(mask, _mm512_setzero_ps(), x);
- x = psub(x, p16f_1);
- e = psub(e, _mm512_mask_blend_ps(mask, _mm512_setzero_ps(), p16f_1));
- x = padd(x, tmp);
-
- Packet16f x2 = pmul(x, x);
- Packet16f x3 = pmul(x2, x);
-
- // Evaluate the polynomial approximant of degree 8 in three parts, probably
- // to improve instruction-level parallelism.
- Packet16f y, y1, y2;
- y = pmadd(p16f_cephes_log_p0, x, p16f_cephes_log_p1);
- y1 = pmadd(p16f_cephes_log_p3, x, p16f_cephes_log_p4);
- y2 = pmadd(p16f_cephes_log_p6, x, p16f_cephes_log_p7);
- y = pmadd(y, x, p16f_cephes_log_p2);
- y1 = pmadd(y1, x, p16f_cephes_log_p5);
- y2 = pmadd(y2, x, p16f_cephes_log_p8);
- y = pmadd(y, x3, y1);
- y = pmadd(y, x3, y2);
- y = pmul(y, x3);
-
- // Add the logarithm of the exponent back to the result of the interpolation.
- y1 = pmul(e, p16f_cephes_log_q1);
- tmp = pmul(x2, p16f_half);
- y = padd(y, y1);
- x = psub(x, tmp);
- y2 = pmul(e, p16f_cephes_log_q2);
- x = padd(x, y);
- x = padd(x, y2);
-
- __mmask16 pos_inf_mask = _mm512_cmp_ps_mask(_x,p16f_pos_inf,_CMP_EQ_OQ);
- // Filter out invalid inputs, i.e.:
- // - negative arg will be NAN,
- // - 0 will be -INF.
- // - +INF will be +INF
- return _mm512_mask_blend_ps(iszero_mask,
- _mm512_mask_blend_ps(invalid_mask,
- _mm512_mask_blend_ps(pos_inf_mask,x,p16f_pos_inf),
- p16f_nan),
- p16f_minus_inf);
+ return plog_float(_x);
+}
+
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8d
+plog<Packet8d>(const Packet8d& _x) {
+ return plog_double(_x);
}
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog)
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h
index 8bb16ce3d..bf7f0db4f 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -118,6 +118,9 @@ template<> struct packet_traits<double> : default_packet_traits
size = 8,
HasHalfPacket = 1,
#if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT)
+#ifdef EIGEN_VECTORIZE_AVX512DQ
+ HasLog = 1,
+#endif
HasSqrt = EIGEN_FAST_MATH,
HasRsqrt = EIGEN_FAST_MATH,
#endif
@@ -185,6 +188,11 @@ EIGEN_STRONG_INLINE Packet16f pset1frombits<Packet16f>(unsigned int from) {
}
template <>
+EIGEN_STRONG_INLINE Packet8d pset1frombits<Packet8d>(uint64_t from) {
+ return _mm512_castsi512_pd(_mm512_set1_epi64(from));
+}
+
+template <>
EIGEN_STRONG_INLINE Packet16f pload1<Packet16f>(const float* from) {
return _mm512_broadcastss_ps(_mm_load_ps1(from));
}
@@ -821,6 +829,20 @@ EIGEN_STRONG_INLINE Packet8d pabs(const Packet8d& a) {
_mm512_set1_epi64(0x7fffffffffffffff)));
}
+template<>
+EIGEN_STRONG_INLINE Packet16f pfrexp<Packet16f>(const Packet16f& a, Packet16f& exponent){
+ return pfrexp_float(a, exponent);
+}
+
+template<>
+EIGEN_STRONG_INLINE Packet8d pfrexp<Packet8d>(const Packet8d& a, Packet8d& exponent){
+ const Packet8d cst_1022d = pset1<Packet8d>(1022.0);
+ const Packet8d cst_half = pset1<Packet8d>(0.5);
+ const Packet8d cst_inv_mant_mask = pset1frombits<Packet8d>(static_cast<uint64_t>(~0x7ff0000000000000ull));
+ exponent = psub(_mm512_cvtepi64_pd(_mm512_srli_epi64(_mm512_castpd_si512(a), 52)), cst_1022d);
+ return por(pand(a, cst_inv_mant_mask), cst_half);
+}
+
#ifdef EIGEN_VECTORIZE_AVX512DQ
// AVX512F does not define _mm512_extractf32x8_ps to extract _m256 from _m512
#define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \
@@ -1264,7 +1286,6 @@ template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f,Packet16i>(const
return _mm512_castsi512_ps(a);
}
-
// Packet math for Eigen::half
template<> EIGEN_STRONG_INLINE Packet16h pset1<Packet16h>(const Eigen::half& from) {
return _mm256_set1_epi16(from.x);