From 3012e755e92d3b3f01f8e7753b5e71cbeaaa40df Mon Sep 17 00:00:00 2001 From: Guoqiang QI <425418567@qq.com> Date: Tue, 15 Sep 2020 17:10:35 +0000 Subject: Add plog ops support packet2d for NEON --- .../Core/arch/Default/GenericPacketMathFunctions.h | 118 +++++++++++++++++++++ .../arch/Default/GenericPacketMathFunctionsFwd.h | 9 ++ Eigen/src/Core/arch/NEON/MathFunctions.h | 3 + Eigen/src/Core/arch/NEON/PacketMath.h | 8 +- Eigen/src/Core/arch/SSE/MathFunctions.h | 5 + Eigen/src/Core/arch/SSE/PacketMath.h | 6 ++ 6 files changed, 148 insertions(+), 1 deletion(-) (limited to 'Eigen/src/Core/arch') diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index e4a0c0919..a0bfada93 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -29,6 +29,16 @@ pfrexp_float(const Packet& a, Packet& exponent) { return por(pand(a, cst_inv_mant_mask), cst_half); } +template EIGEN_STRONG_INLINE Packet +pfrexp_double(const Packet& a, Packet& exponent) { /* frexp for double packets: returns the significand scaled into [0.5,1) with the sign of a, and stores the matching power-of-two exponent (as a double) in `exponent`; only meaningful for normal, finite inputs */ + typedef typename unpacket_traits::integer_packet PacketI; + const Packet cst_1022d = pset1(1022.0); + const Packet cst_half = pset1(0.5); + const Packet cst_inv_mant_mask = pset1frombits(~0x7ff0000000000000u); + exponent = psub(pcast(plogical_shift_right<52>(preinterpret(pabs(a)))), cst_1022d); /* pabs() clears the sign bit so that only the 11 biased-exponent bits survive the shift; 1022 = bias (1023) - 1 because the significand is in [0.5,1) */ + return por(pand(a, cst_inv_mant_mask), cst_half); +} + template EIGEN_STRONG_INLINE Packet pldexp_float(Packet a, Packet exponent) { @@ -139,6 +149,114 @@ Packet plog_float(const Packet _x) por(pselect(pos_inf_mask,cst_pos_inf,x), invalid_mask)); } + +/* Returns the base e (2.718...) logarithm of x. + * The argument is separated into its exponent and fractional + * parts. If the exponent is between -1 and +1, the logarithm + * of the fraction is approximated by + * + * log(1+x) = x - 0.5 x**2 + x**3 P(x)/Q(x). + * + * Otherwise, setting z = 2(x-1)/(x+1), + * log(x) = z + z**3 P(z)/Q(z). 
+ * + * for more detail see: http://www.netlib.org/cephes/ + */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet plog_double(const Packet _x) +{ + Packet x = _x; + + const Packet cst_1 = pset1(1.0); + const Packet cst_half = pset1(0.5); + // The smallest positive normalized double-precision number. + const Packet cst_min_norm_pos = pset1frombits( 0x0010000000000000u); + const Packet cst_minus_inf = pset1frombits( 0xfff0000000000000u); + const Packet cst_pos_inf = pset1frombits( 0x7ff0000000000000u); + + // Polynomial Coefficients for log(1+x) = x - x**2/2 + x**3 P(x)/Q(x) + // 1/sqrt(2) <= x < sqrt(2) + const Packet cst_cephes_SQRTHF = pset1(0.70710678118654752440E0); + const Packet cst_cephes_log_p0 = pset1(1.01875663804580931796E-4); + const Packet cst_cephes_log_p1 = pset1(4.97494994976747001425E-1); + const Packet cst_cephes_log_p2 = pset1(4.70579119878881725854E0); + const Packet cst_cephes_log_p3 = pset1(1.44989225341610930846E1); + const Packet cst_cephes_log_p4 = pset1(1.79368678507819816313E1); + const Packet cst_cephes_log_p5 = pset1(7.70838733755885391666E0); + + const Packet cst_cephes_log_r0 = pset1(1.0); + const Packet cst_cephes_log_r1 = pset1(1.12873587189167450590E1); + const Packet cst_cephes_log_r2 = pset1(4.52279145837532221105E1); + const Packet cst_cephes_log_r3 = pset1(8.29875266912776603211E1); + const Packet cst_cephes_log_r4 = pset1(7.11544750618563894466E1); + const Packet cst_cephes_log_r5 = pset1(2.31251620126765340583E1); + + const Packet cst_cephes_log_q1 = pset1(-2.121944400546905827679e-4); + const Packet cst_cephes_log_q2 = pset1(0.693359375); /* q1 + q2 = ln(2): a coarse part (q2 = 355/512) plus a small correction (q1) */ + + // Truncate input values to the minimum positive normal. + x = pmax(x, cst_min_norm_pos); + + Packet e; + // Extract the significand in the range [0.5,1) and the exponent. + x = pfrexp(x,e); + + // Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2)) + // and shift by -1. 
The values are then centered around 0, which improves + // the stability of the polynomial evaluation. + // if( x < SQRTHF ) { + // e -= 1; + // x = x + x - 1.0; + // } else { x = x - 1.0; } + Packet mask = pcmp_lt(x, cst_cephes_SQRTHF); + Packet tmp = pand(x, mask); + x = psub(x, cst_1); + e = psub(e, pand(cst_1, mask)); + x = padd(x, tmp); + + Packet x2 = pmul(x, x); + Packet x3 = pmul(x2, x); + + // Evaluate the rational approximant x**3 P(x)/Q(x); each polynomial is split into two halves, probably to improve instruction-level parallelism. + // y = x * ( z * polevl( x, P, 5 ) / p1evl( x, Q, 5 ) ); + Packet y, y1, y2,y_; + y = pmadd(cst_cephes_log_p0, x, cst_cephes_log_p1); + y1 = pmadd(cst_cephes_log_p3, x, cst_cephes_log_p4); + y = pmadd(y, x, cst_cephes_log_p2); + y1 = pmadd(y1, x, cst_cephes_log_p5); + y_ = pmadd(y, x3, y1); + + y = pmadd(cst_cephes_log_r0, x, cst_cephes_log_r1); + y1 = pmadd(cst_cephes_log_r3, x, cst_cephes_log_r4); + y = pmadd(y, x, cst_cephes_log_r2); + y1 = pmadd(y1, x, cst_cephes_log_r5); + y = pmadd(y, x3, y1); + + y_ = pmul(y_, x3); + y = pdiv(y_, y); + + // Add the logarithm of the exponent back to the result of the interpolation. + y1 = pmul(e, cst_cephes_log_q1); + tmp = pmul(x2, cst_half); + y = padd(y, y1); + x = psub(x, tmp); + y2 = pmul(e, cst_cephes_log_q2); + x = padd(x, y); + x = padd(x, y2); + + Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x)); + Packet iszero_mask = pcmp_eq(_x,pzero(_x)); + Packet pos_inf_mask = pcmp_eq(_x,cst_pos_inf); + // Filter out invalid inputs, i.e.: + // - negative arg will be NAN + // - 0 will be -INF + // - +INF will be +INF + return pselect(iszero_mask, cst_minus_inf, + por(pselect(pos_inf_mask,cst_pos_inf,x), invalid_mask)); +} + /** \internal \returns log(1 + x) computed using W. Kahan's formula. 
See: http://www.plunk.org/~hatch/rightway.php */ diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h index 68153cae3..0e02a1b20 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h @@ -20,6 +20,9 @@ namespace internal { template EIGEN_STRONG_INLINE Packet pfrexp_float(const Packet& a, Packet& exponent); +template EIGEN_STRONG_INLINE Packet +pfrexp_double(const Packet& a, Packet& exponent); + template EIGEN_STRONG_INLINE Packet pldexp_float(Packet a, Packet exponent); @@ -29,6 +32,12 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet plog_float(const Packet _x); +/** \internal \returns log(x) for double precision float */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet plog_double(const Packet _x); + /** \internal \returns log(1 + x) */ template Packet generic_plog1p(const Packet& x); diff --git a/Eigen/src/Core/arch/NEON/MathFunctions.h b/Eigen/src/Core/arch/NEON/MathFunctions.h index 8bea0ac3c..28167b904 100644 --- a/Eigen/src/Core/arch/NEON/MathFunctions.h +++ b/Eigen/src/Core/arch/NEON/MathFunctions.h @@ -51,6 +51,9 @@ BF16_PACKET_FUNCTION(Packet4f, Packet4bf, ptanh) template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2d pexp(const Packet2d& x) { return pexp_double(x); } +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2d plog(const Packet2d& x) +{ return plog_double(x); } + #endif } // end namespace internal diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h index 2d0f91e2f..530adfeec 100644 --- a/Eigen/src/Core/arch/NEON/PacketMath.h +++ b/Eigen/src/Core/arch/NEON/PacketMath.h @@ -3579,7 +3579,7 @@ template<> struct packet_traits : default_packet_traits HasSin = 0, HasCos = 0, - HasLog = 0, + HasLog = 1, HasExp = 1, HasSqrt = 1, 
HasTanh = 0, @@ -3753,6 +3753,12 @@ template<> EIGEN_DEVICE_FUNC inline Packet2d pselect( const Packet2d& mask, cons template<> EIGEN_STRONG_INLINE Packet2d pldexp(const Packet2d& a, const Packet2d& exponent) { return pldexp_double(a, exponent); } +template<> EIGEN_STRONG_INLINE Packet2d pfrexp(const Packet2d& a, Packet2d& exponent) +{ return pfrexp_double(a,exponent); } + +template<> EIGEN_STRONG_INLINE Packet2d pset1frombits(unsigned long from) +{ return vreinterpretq_f64_u64(vdupq_n_u64(from)); } /* NOTE(review): callers pass 64-bit masks, but `unsigned long` is only 32 bits wide on LLP64 and 32-bit targets; a fixed-width 64-bit parameter type may be needed -- confirm against the supported targets */ + #if EIGEN_FAST_MATH // Functions for sqrt support packet2d. diff --git a/Eigen/src/Core/arch/SSE/MathFunctions.h b/Eigen/src/Core/arch/SSE/MathFunctions.h index 92c1eecc7..71ec6f858 100644 --- a/Eigen/src/Core/arch/SSE/MathFunctions.h +++ b/Eigen/src/Core/arch/SSE/MathFunctions.h @@ -24,6 +24,11 @@ Packet4f plog(const Packet4f& _x) { return plog_float(_x); } +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet2d plog(const Packet2d& _x) { + return plog_double(_x); +} + template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f plog1p(const Packet4f& _x) { return generic_plog1p(_x); diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h index 25705e7b2..c461420c5 100755 --- a/Eigen/src/Core/arch/SSE/PacketMath.h +++ b/Eigen/src/Core/arch/SSE/PacketMath.h @@ -132,6 +132,7 @@ struct packet_traits : default_packet_traits { HasCmp = 1, HasDiv = 1, + HasLog = 1, HasExp = 1, HasSqrt = 1, HasRsqrt = 1, @@ -227,6 +228,7 @@ template<> EIGEN_STRONG_INLINE Packet4i pset1(const int& from) { re template<> EIGEN_STRONG_INLINE Packet16b pset1(const bool& from) { return _mm_set1_epi8(static_cast(from)); } template<> EIGEN_STRONG_INLINE Packet4f pset1frombits(unsigned int from) { return _mm_castsi128_ps(pset1(from)); } +template<> EIGEN_STRONG_INLINE Packet2d pset1frombits(unsigned long from) { return _mm_castsi128_pd(_mm_set1_epi64x(from)); } /* NOTE(review): same width concern as any 64-bit bit pattern passed through `unsigned long` -- it is 32 bits on LLP64 and 32-bit x86 targets; verify the double-precision masks are not truncated there */ template<> EIGEN_STRONG_INLINE Packet4f pzero(const 
Packet4f& /*a*/) { return _mm_setzero_ps(); } template<> EIGEN_STRONG_INLINE Packet2d pzero(const Packet2d& /*a*/) { return _mm_setzero_pd(); } @@ -753,6 +755,10 @@ template<> EIGEN_STRONG_INLINE Packet4f pfrexp(const Packet4f& a, Pack return pfrexp_float(a,exponent); } +template<> EIGEN_STRONG_INLINE Packet2d pfrexp(const Packet2d& a, Packet2d& exponent) { + return pfrexp_double(a,exponent); +} + template<> EIGEN_STRONG_INLINE Packet4f pldexp(const Packet4f& a, const Packet4f& exponent) { return pldexp_float(a,exponent); } -- cgit v1.2.3