diff options
-rw-r--r-- | Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h | 512 | ||||
-rw-r--r-- | test/array_cwise.cpp | 18 |
2 files changed, 394 insertions, 136 deletions
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index 42d310ab2..abdbcdbb9 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -1,3 +1,4 @@ + // This file is part of Eigen, a lightweight C++ template library // for linear algebra. // @@ -36,7 +37,7 @@ template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pfrexp_generic_get_biased_exponent(const Packet& a) { typedef typename unpacket_traits<Packet>::type Scalar; typedef typename unpacket_traits<Packet>::integer_packet PacketI; - EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1; + enum { mantissa_bits = numext::numeric_limits<Scalar>::digits - 1}; return pcast<PacketI, Packet>(plogical_shift_right<mantissa_bits>(preinterpret<PacketI>(pabs(a)))); } @@ -46,28 +47,29 @@ template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pfrexp_generic(const Packet& a, Packet& exponent) { typedef typename unpacket_traits<Packet>::type Scalar; typedef typename make_unsigned<typename make_integer<Scalar>::type>::type ScalarUI; - - EIGEN_CONSTEXPR int total_bits = sizeof(Scalar) * CHAR_BIT; - EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1; - EIGEN_CONSTEXPR int exponent_bits = total_bits - mantissa_bits - 1; - + enum { + TotalBits = sizeof(Scalar) * CHAR_BIT, + MantissaBits = numext::numeric_limits<Scalar>::digits - 1, + ExponentBits = int(TotalBits) - int(MantissaBits) - 1 + }; + EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask = - ~(((ScalarUI(1) << exponent_bits) - ScalarUI(1)) << mantissa_bits); // ~0x7f800000 + ~(((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)) << int(MantissaBits)); // ~0x7f800000 const Packet sign_mantissa_mask = pset1frombits<Packet>(static_cast<ScalarUI>(scalar_sign_mantissa_mask)); const Packet half = pset1<Packet>(Scalar(0.5)); const Packet zero = pzero(a); const Packet normal_min = pset1<Packet>((numext::numeric_limits<Scalar>::min)()); // Minimum normal value, 2^-126 - // To handle denormals, normalize by multiplying by 2^(mantissa_bits+1). + // To handle denormals, normalize by multiplying by 2^(int(MantissaBits)+1). const Packet is_denormal = pcmp_lt(pabs(a), normal_min); - EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(mantissa_bits + 1); // 24 + EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(int(MantissaBits) + 1); // 24 // The following cannot be constexpr because bfloat16(uint16_t) is not constexpr. const Scalar scalar_normalization_factor = Scalar(ScalarUI(1) << int(scalar_normalization_offset)); // 2^24 const Packet normalization_factor = pset1<Packet>(scalar_normalization_factor); const Packet normalized_a = pselect(is_denormal, pmul(a, normalization_factor), a); // Determine exponent offset: -126 if normal, -126-24 if denormal - const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(exponent_bits-1)) - ScalarUI(2)); // -126 + const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(int(ExponentBits)-1)) - ScalarUI(2)); // -126 Packet exponent_offset = pset1<Packet>(scalar_exponent_offset); const Packet normalization_offset = pset1<Packet>(-Scalar(scalar_normalization_offset)); // -24 exponent_offset = pselect(is_denormal, padd(exponent_offset, normalization_offset), exponent_offset); @@ -76,7 +78,7 @@ Packet pfrexp_generic(const Packet& a, Packet& exponent) { exponent = pfrexp_generic_get_biased_exponent(normalized_a); // Zero, Inf and NaN return 'a' unmodified, exponent is zero // (technically the exponent is unspecified for inf/NaN, but GCC/Clang set it to zero) - const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << exponent_bits) - ScalarUI(1)); // 255 + const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)); // 255 const Packet non_finite_exponent = pset1<Packet>(scalar_non_finite_exponent); const Packet is_zero_or_not_finite = por(pcmp_eq(a, zero), pcmp_eq(exponent, non_finite_exponent)); const Packet m = pselect(is_zero_or_not_finite, a, por(pand(normalized_a, sign_mantissa_mask), half)); @@ -113,18 +115,20 @@ Packet pldexp_generic(const Packet& a, const Packet& exponent) { typedef typename unpacket_traits<Packet>::integer_packet PacketI; typedef typename unpacket_traits<Packet>::type Scalar; typedef typename unpacket_traits<PacketI>::type ScalarI; - EIGEN_CONSTEXPR int total_bits = sizeof(Scalar) * CHAR_BIT; - EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1; - EIGEN_CONSTEXPR int exponent_bits = total_bits - mantissa_bits - 1; + enum { + TotalBits = sizeof(Scalar) * CHAR_BIT, + MantissaBits = numext::numeric_limits<Scalar>::digits - 1, + ExponentBits = int(TotalBits) - int(MantissaBits) - 1 + }; - const Packet max_exponent = pset1<Packet>(Scalar((ScalarI(1)<<exponent_bits) + ScalarI(mantissa_bits - 1))); // 278 - const PacketI bias = pset1<PacketI>((ScalarI(1)<<(exponent_bits-1)) - ScalarI(1)); // 127 + const Packet max_exponent = pset1<Packet>(Scalar((ScalarI(1)<<int(ExponentBits)) + ScalarI(int(MantissaBits) - 1))); // 278 + const PacketI bias = pset1<PacketI>((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1)); // 127 const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent)); PacketI b = parithmetic_shift_right<2>(e); // floor(e/4); - Packet c = preinterpret<Packet>(plogical_shift_left<mantissa_bits>(padd(b, bias))); // 2^b + Packet c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^b Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b) b = psub(psub(psub(e, b), b), b); // e - 3b - c = preinterpret<Packet>(plogical_shift_left<mantissa_bits>(padd(b, bias))); // 2^(e-3*b) + c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^(e-3*b) out = pmul(out, c); return out; } @@ -890,112 +894,355 @@ Packet psqrt_complex(const Packet& a) { pselect(is_real_inf, real_inf_result,result)); } +// TODO(rmlarsen): The following set of utilities for double word arithmetic +// should perhaps be refactored as a separate file, since it would be generally +// useful for special function implementation etc. Writing the algorithms in +// terms if a double word type would also make the code more readable. + +// This function splits x into the nearest integer n and fractional part r, +// such that x = n + r holds exactly. +template<typename Packet> +EIGEN_STRONG_INLINE +void absolute_split(const Packet& x, Packet& n, Packet& r) { + n = pround(x); + r = psub(x, n); +} + +// This function computes the sum {s, r}, such that x + y = s_hi + s_lo +// holds exactly, and s_hi = fl(x+y), if |x| >= |y|. +template<typename Packet> +EIGEN_STRONG_INLINE +void fast_twosum(const Packet& x, const Packet& y, Packet& s_hi, Packet& s_lo) { + s_hi = padd(x, y); + const Packet t = psub(s_hi, x); + s_lo = psub(y, t); +} + +#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD +// This function implements the extended precision product of +// a pair of floating point numbers. Given {x, y}, it computes the pair +// {p_hi, p_lo} such that x * y = p_hi + p_lo holds exactly and +// p_hi = fl(x * y). +template<typename Packet> +EIGEN_STRONG_INLINE +void twoprod(const Packet& x, const Packet& y, + Packet& p_hi, Packet& p_lo) { + p_hi = pmul(x, y); + p_lo = pmadd(x, y, pnegate(p_hi)); +} + +#else // This function implements the Veltkamp splitting. Given a floating point // number x it returns the pair {x_hi, x_lo} such that x_hi + x_lo = x holds // exactly and that half of the significant of x fits in x_hi. -// This code corresponds to Algorithms 3 and 4 in -// https://hal.inria.fr/hal-01774587v2/document +// This is Algorithm 3 from Jean-Michel Muller, "Elementary Functions", +// 3rd edition, Birkh\"auser, 2016. template<typename Packet> EIGEN_STRONG_INLINE void veltkamp_splitting(const Packet& x, Packet& x_hi, Packet& x_lo) { typedef typename unpacket_traits<Packet>::type Scalar; EIGEN_CONSTEXPR int shift = (NumTraits<Scalar>::digits() + 1) / 2; - Scalar shift_scale = Scalar(uint64_t(1) << shift); // Scalar constructor not necessarily constexpr. - Packet gamma = pmul(pset1<Packet>(shift_scale + Scalar(1)), x); -#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD - x_hi = pmadd(pset1<Packet>(-shift_scale), x, gamma); -#else + const Scalar shift_scale = Scalar(uint64_t(1) << shift); // Scalar constructor not necessarily constexpr. + const Packet gamma = pmul(pset1<Packet>(shift_scale + Scalar(1)), x); Packet rho = psub(x, gamma); x_hi = padd(rho, gamma); -#endif x_lo = psub(x, x_hi); } -// This function splits x into the nearest integer n and fractional part r, -// such that x = n + r holds exactly. +// This function implements Dekker's algorithm for products x * y. +// Given floating point numbers {x, y} computes the pair +// {p_hi, p_lo} such that x * y = p_hi + p_lo holds exactly and +// p_hi = fl(x * y). template<typename Packet> EIGEN_STRONG_INLINE -void integer_split(const Packet& x, Packet& n, Packet& r) { - n = pround(x); - r = psub(x, n); +void twoprod(const Packet& x, const Packet& y, + Packet& p_hi, Packet& p_lo) { + Packet x_hi, x_lo, y_hi, y_lo; + veltkamp_splitting(x, x_hi, x_lo); + veltkamp_splitting(y, y_hi, y_lo); + + p_hi = pmul(x, y); + p_lo = pmadd(x_hi, y_hi, pnegate(p_hi)); + p_lo = pmadd(x_hi, y_lo, p_lo); + p_lo = pmadd(x_lo, y_hi, p_lo); + p_lo = pmadd(x_lo, y_lo, p_lo); } -// This function implements Dekker's algorithm for two products {x * y1, x * y2} with -// a shared factor. Given floating point numbers {x, y1, y2} computes the pairs -// {p1, r1} and {p2, r2} such that x * y1 = p1 + r1 holds exactly and -// p1 = fl(x * y1), and x * y2 = p2 + r2 holds exactly and p2 = fl(x * y2). +#endif // EIGEN_HAS_SINGLE_INSTRUCTION_MADD + + +// This function implements Dekker's algorithm for the addition +// of two double word numbers represented by {x_hi, x_lo} and {y_hi, y_lo}. +// It returns the result as a pair {s_hi, s_lo} such that +// x_hi + x_lo + y_hi + y_lo = s_hi + s_lo holds exactly. +// This is Algorithm 5 from Jean-Michel Muller, "Elementary Functions", +// 3rd edition, Birkh\"auser, 2016. template<typename Packet> EIGEN_STRONG_INLINE -void double_dekker(const Packet& x, const Packet& y1, const Packet& y2, - Packet& p1, Packet& r1, Packet& p2, Packet& r2) { - Packet x_hi, x_lo, y1_hi, y1_lo, y2_hi, y2_lo; - veltkamp_splitting(x, x_hi, x_lo); - veltkamp_splitting(y1, y1_hi, y1_lo); - veltkamp_splitting(y2, y2_hi, y2_lo); - - p1 = pmul(x, y1); - r1 = pmadd(x_hi, y1_hi, pnegate(p1)); - r1 = pmadd(x_hi, y1_lo, r1); - r1 = pmadd(x_lo, y1_hi, r1); - r1 = pmadd(x_lo, y1_lo, r1); - - p2 = pmul(x, y2); - r2 = pmadd(x_hi, y2_hi, pnegate(p2)); - r2 = pmadd(x_hi, y2_lo, r2); - r2 = pmadd(x_lo, y2_hi, r2); - r2 = pmadd(x_lo, y2_lo, r2); + void twosum(const Packet& x_hi, const Packet& x_lo, + const Packet& y_hi, const Packet& y_lo, + Packet& s_hi, Packet& s_lo) { + const Packet x_greater_mask = pcmp_lt(pabs(y_hi), pabs(x_hi)); + Packet r_hi_1, r_lo_1; + fast_twosum(x_hi, y_hi,r_hi_1, r_lo_1); + Packet r_hi_2, r_lo_2; + fast_twosum(y_hi, x_hi,r_hi_2, r_lo_2); + const Packet r_hi = pselect(x_greater_mask, r_hi_1, r_hi_2); + + const Packet s1 = padd(padd(y_lo, r_lo_1), x_lo); + const Packet s2 = padd(padd(x_lo, r_lo_2), y_lo); + const Packet s = pselect(x_greater_mask, s1, s2); + + fast_twosum(r_hi, s, s_hi, s_lo); } -// This function implements the non-trivial case of pow(x,y) where x is -// positive and y is (possibly) non-integer. -// Formally, pow(x,y) = 2**(y * log2(x)) +// This is a version of twosum for double word numbers, +// which assumes that |x_hi| >= |y_hi|. +template<typename Packet> +EIGEN_STRONG_INLINE + void fast_twosum(const Packet& x_hi, const Packet& x_lo, + const Packet& y_hi, const Packet& y_lo, + Packet& s_hi, Packet& s_lo) { + Packet r_hi, r_lo; + fast_twosum(x_hi, y_hi, r_hi, r_lo); + const Packet s = padd(padd(y_lo, r_lo), x_lo); + fast_twosum(r_hi, s, s_hi, s_lo); +} + +// This function implements the multiplication of a double word +// number represented by {x_hi, x_lo} by a floating point number y. +// It returns the result as a pair {p_hi, p_lo} such that +// (x_hi + x_lo) * y = p_hi + p_lo hold with a relative error +// of less than 2*2^{-2p}, where p is the number of significand bit +// in the floating point type. +// This is Algorithm 7 from Jean-Michel Muller, "Elementary Functions", +// 3rd edition, Birkh\"auser, 2016. template<typename Packet> EIGEN_STRONG_INLINE -Packet generic_pow_impl(const Packet& x, const Packet& y) { +void twoprod(const Packet& x_hi, const Packet& x_lo, const Packet& y, + Packet& p_hi, Packet& p_lo) { + Packet c_hi, c_lo1; + twoprod(x_hi, y, c_hi, c_lo1); + const Packet c_lo2 = pmul(x_lo, y); + Packet t_hi, t_lo1; + fast_twosum(c_hi, c_lo2, t_hi, t_lo1); + const Packet t_lo2 = padd(t_lo1, c_lo1); + fast_twosum(t_hi, t_lo2, p_hi, p_lo); +} + +// This function computes log2(x) and returns the result as a double word. +template <typename Scalar> +struct accurate_log2 { + template <typename Packet> + EIGEN_STRONG_INLINE + void operator()(const Packet& x, Packet& log2_x_hi, Packet& log2_x_lo) { + log2_x_hi = plog2(x); + log2_x_lo = pzero(x); + } +}; + +// This specialization uses a more accurate algorithm to compute log2(x) for +// floats in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~6.42e-10. +// This additional accuracy is needed to counter the error-magnification +// inherent in multiplying by a potentially large exponent in pow(x,y). +// The minimax polynomial used was calculated using the Sollya tool. +// See sollya.org. +template <> +struct accurate_log2<float> { + template <typename Packet> + EIGEN_STRONG_INLINE + void operator()(const Packet& z, Packet& log2_x_hi, Packet& log2_x_lo) { + // The function log(1+x)/x is approximated in the interval + // [1/sqrt(2)-1;sqrt(2)-1] by a degree 10 polynomial of the form + // Q(x) = (C0 + x * (C1 + x * (C2 + x * (C3 + x * P(x))))), + // where the degree 6 polynomial P(x) is evaluated in single precision, + // while the remaining 4 terms of Q(x), as well as the final multiplication by x + // to reconstruct log(1+x) are evaluated in extra precision using + // double word arithmetic. C0 through C3 are extra precise constants + // stored as double words. + // + // The polynomial coefficients were calculated using Sollya commands: + // > n = 10; + // > f = log2(1+x)/x; + // > interval = [sqrt(0.5)-1;sqrt(2)-1]; + // > p = fpminimax(f,n,[|double,double,double,double,single...|],interval,relative,floating); + + const Packet p6 = pset1<Packet>( 9.703654795885e-2f); + const Packet p5 = pset1<Packet>(-0.1690667718648f); + const Packet p4 = pset1<Packet>( 0.1720575392246f); + const Packet p3 = pset1<Packet>(-0.1789081543684f); + const Packet p2 = pset1<Packet>( 0.2050433009862f); + const Packet p1 = pset1<Packet>(-0.2404672354459f); + const Packet p0 = pset1<Packet>( 0.2885761857032f); + + const Packet C3_hi = pset1<Packet>(-0.360674142838f); + const Packet C3_lo = pset1<Packet>(-6.13283912543e-09f); + const Packet C2_hi = pset1<Packet>(0.480897903442f); + const Packet C2_lo = pset1<Packet>(-1.44861207474e-08f); + const Packet C1_hi = pset1<Packet>(-0.721347510815f); + const Packet C1_lo = pset1<Packet>(-4.84483164698e-09f); + const Packet C0_hi = pset1<Packet>(1.44269502163f); + const Packet C0_lo = pset1<Packet>(2.01711713999e-08f); + const Packet one = pset1<Packet>(1.0f); + + const Packet x = psub(z, one); + // Evaluate P(x) in working precision. + // We evaluate it in multiple parts to improve instruction level + // parallelism. + Packet x2 = pmul(x,x); + Packet p_even = pmadd(p6, x2, p4); + p_even = pmadd(p_even, x2, p2); + p_even = pmadd(p_even, x2, p0); + Packet p_odd = pmadd(p5, x2, p3); + p_odd = pmadd(p_odd, x2, p1); + Packet p = pmadd(p_odd, x, p_even); + + // Now evaluate the low-order tems of Q(x) in double word precision. + // In the following, due to the alternating signs and the fact that + // |x| < sqrt(2)-1, we can assume that |C*_hi| >= q_i, and use + // fast_twosum instead of the slower twosum. + Packet q_hi, q_lo; + Packet t_hi, t_lo; + // C3 + x * p(x) + twoprod(p, x, t_hi, t_lo); + fast_twosum(C3_hi, C3_lo, t_hi, t_lo, q_hi, q_lo); + // C2 + x * p(x) + twoprod(q_hi, q_lo, x, t_hi, t_lo); + fast_twosum(C2_hi, C2_lo, t_hi, t_lo, q_hi, q_lo); + // C1 + x * p(x) + twoprod(q_hi, q_lo, x, t_hi, t_lo); + fast_twosum(C1_hi, C1_lo, t_hi, t_lo, q_hi, q_lo); + // C0 + x * p(x) + twoprod(q_hi, q_lo, x, t_hi, t_lo); + fast_twosum(C0_hi, C0_lo, t_hi, t_lo, q_hi, q_lo); + + // log(z) ~= x * Q(x) + twoprod(q_hi, q_lo, x, log2_x_hi, log2_x_lo); + } +}; + +// This function computes exp2(x) (i.e. 2**x). +template <typename Scalar> +struct fast_accurate_exp2 { + template <typename Packet> + EIGEN_STRONG_INLINE + Packet operator()(const Packet& x) { + // TODO(rmlarsen): Add a pexp2 packetop. + return pexp(pmul(pset1<Packet>(Scalar(EIGEN_LN2)), x)); + } +}; + +// This specialization uses a faster algorithm to compute exp2(x) for floats +// in [-0.5;0.5] with a relative accuracy of 1 ulp. +// The minimax polynomial used was calculated using the Sollya tool. +// See sollya.org. +template <> +struct fast_accurate_exp2<float> { + template <typename Packet> + EIGEN_STRONG_INLINE + Packet operator()(const Packet& x) { + // This function approximates exp2(x) by a degree 6 polynomial of the form + // Q(x) = 1 + x * (C + x * P(x)), where the degree 4 polynomial P(x) is evaluated in + // single precision, and the remaining steps are evaluated with extra precision using + // double word arithmetic. C is an extra precise constant stored as a double word. + // + // The polynomial coefficients were calculated using Sollya commands: + // > n = 6; + // > f = 2^x; + // > interval = [-0.5;0.5]; + // > p = fpminimax(f,n,[|1,double,single...|],interval,relative,floating); + + const Packet p4 = pset1<Packet>(1.539513905e-4f); + const Packet p3 = pset1<Packet>(1.340007293e-3f); + const Packet p2 = pset1<Packet>(9.618283249e-3f); + const Packet p1 = pset1<Packet>(5.550328270e-2f); + const Packet p0 = pset1<Packet>(0.2402264923f); + + const Packet C_hi = pset1<Packet>(0.6931471825f); + const Packet C_lo = pset1<Packet>(2.36836577e-08f); + const Packet one = pset1<Packet>(1.0f); + + // Evaluate P(x) in working precision. + // We evaluate even and odd parts of the polynomial separately + // to gain some instruction level parallelism. + Packet x2 = pmul(x,x); + Packet p_even = pmadd(p4, x2, p2); + p_even = pmadd(p_even, x2, p0); + Packet p_odd = pmadd(p3, x2, p1); + Packet p = pmadd(p_odd, x, p_even); + + // Evaluate the remaining terms of Q(x) with extra precision using + // double word arithmetic. + Packet p_hi, p_lo; + // x * p(x) + twoprod(p, x, p_hi, p_lo); + // C + x * p(x) + Packet q1_hi, q1_lo; + twosum(p_hi, p_lo, C_hi, C_lo, q1_hi, q1_lo); + // x * (C + x * p(x)) + Packet q2_hi, q2_lo; + twoprod(q1_hi, q1_lo, x, q2_hi, q2_lo); + // 1 + x * (C + x * p(x)) + Packet q3_hi, q3_lo; + // Since |q2_hi| <= sqrt(2)-1 < 1, we can use fast_twosum + // for adding it to unity here. + fast_twosum(one, q2_hi, q3_hi, q3_lo); + return padd(q3_hi, padd(q2_lo, q3_lo)); + } +}; + +// This function implements the non-trivial case of pow(x,y) where x is +// positive and y is (possibly) non-integer. +// Formally, pow(x,y) = exp2(y * log2(x)), where exp2(x) is shorthand for 2^x. +// TODO(rmlarsen): We should probably add this as a packet up 'ppow', to make it +// easier to specialize or turn off for specific types and/or backends.x +template <typename Packet> +EIGEN_STRONG_INLINE Packet generic_pow_impl(const Packet& x, const Packet& y) { typedef typename unpacket_traits<Packet>::type Scalar; // Split x into exponent e_x and mantissa m_x. Packet e_x; Packet m_x = pfrexp(x, e_x); - // Adjust m_x to lie in [0.75:1.5) to minimize absolute error in log2(m_x). - Packet m_x_scale_mask = pcmp_lt(m_x, pset1<Packet>(Scalar(0.75))); + // Adjust m_x to lie in [1/sqrt(2):sqrt(2)] to minimize absolute error in log2(m_x). + EIGEN_CONSTEXPR Scalar sqrt_half = Scalar(0.70710678118654752440); + const Packet m_x_scale_mask = pcmp_lt(m_x, pset1<Packet>(sqrt_half)); m_x = pselect(m_x_scale_mask, pmul(pset1<Packet>(Scalar(2)), m_x), m_x); e_x = pselect(m_x_scale_mask, psub(e_x, pset1<Packet>(Scalar(1))), e_x); - Packet r_x = plog2(m_x); + // Compute log2(m_x) with 6 extra bits of accuracy. + Packet rx_hi, rx_lo; + accurate_log2<Scalar>()(m_x, rx_hi, rx_lo); // Compute the two terms {y * e_x, y * r_x} in f = y * log2(x) with doubled - // precision using Dekker's algorithm. + // precision using double word arithmetic. Packet f1_hi, f1_lo, f2_hi, f2_lo; - double_dekker(y, e_x, r_x, f1_hi, f1_lo, f2_hi, f2_lo); - - // Separate f into integer and fractional parts, keeping f1_hi, and f2_hi - // separate to avoid cancellation. - Packet n1, r1, n2, r2; - integer_split(f1_hi, n1, r1); - integer_split(f2_hi, n2, r2); - - // Add up integer parts and sum the remainders. - Packet n_z = padd(n1, n2); - // Notice: I experimented with using compensated (Kahan) summation here, - // but it does not seem to matter. - Packet rem = padd(padd(f1_lo, f2_lo), padd(r1, r2)); - - // Extract any additional integer part that may have accumulated in rem. - Packet nrem, r_z; - integer_split(rem, nrem, r_z); - n_z = padd(n_z, nrem); + twoprod(e_x, y, f1_hi, f1_lo); + twoprod(rx_hi, rx_lo, y, f2_hi, f2_lo); + // Sum the two terms in f using double word arithmetic. We know + // that |e_x| > |log2(m_x)|, except for the case where e_x==0. + // This means that we can use fast_twosum(f1,f2). + // In the case e_x == 0, e_x * y = f1 = 0, so we don't lose any + // accuracy by violating the assumption of fast_twosum, because + // it's a no-op. + Packet f_hi, f_lo; + fast_twosum(f1_hi, f1_lo, f2_hi, f2_lo, f_hi, f_lo); + + // Split f into integer and fractional parts. + Packet n_z, r_z; + absolute_split(f_hi, n_z, r_z); + r_z = padd(r_z, f_lo); + Packet n_r; + absolute_split(r_z, n_r, r_z); + n_z = padd(n_z, n_r); // We now have an accurate split of f = n_z + r_z and can compute - // x^y = 2**{n_z + r_z) = exp(ln(2) * r_z) * 2**{n_z}. - // The first factor we compute by calling pexp(), while multiplication - // by an integer power of 2 can be done exactly using pldexp(). - // Note: I experimented with using Dekker's algorithms for the - // multiplication by ln(2) here, but did not see any difference. - Packet e_r = pexp(pmul(pset1<Packet>(Scalar(EIGEN_LN2)), r_z)); - // TODO: investigate bounds of e_r and n_z, potentially using faster - // implementation of ldexp. + // x^y = 2**{n_z + r_z) = exp2(r_z) * 2**{n_z}. + // Since r_z is in [-0.5;0.5], we compute the first factor to high accuracy + // using a specialized algorithm. Multiplication by the second factor can + // be done exactly using pldexp(), since it is an integer power of 2. + // Packet e_r = fast_accurate_exp2<Scalar>()(r_z); + const Packet e_r = fast_accurate_exp2<Scalar>()(r_z); return pldexp(e_r, n_z); } @@ -1005,66 +1252,75 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet generic_pow(const Packet& x, const Packet& y) { typedef typename unpacket_traits<Packet>::type Scalar; + const Packet cst_pos_inf = pset1<Packet>(NumTraits<Scalar>::infinity()); const Packet cst_zero = pset1<Packet>(Scalar(0)); const Packet cst_one = pset1<Packet>(Scalar(1)); - const Packet cst_half = pset1<Packet>(Scalar(0.5)); const Packet cst_nan = pset1<Packet>(NumTraits<Scalar>::quiet_NaN()); - Packet abs_x = pabs(x); + const Packet abs_x = pabs(x); // Predicates for sign and magnitude of x. - Packet x_is_zero = pcmp_eq(x, cst_zero); - Packet x_is_neg = pcmp_lt(x, cst_zero); - Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf); - Packet abs_x_is_one = pcmp_eq(abs_x, cst_one); - Packet abs_x_is_gt_one = pcmp_lt(cst_one, abs_x); - Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_one); - Packet x_is_one = pandnot(abs_x_is_one, x_is_neg); - Packet x_is_neg_one = pand(abs_x_is_one, x_is_neg); - Packet x_is_nan = pandnot(ptrue(x), pcmp_eq(x, x)); + const Packet x_is_zero = pcmp_eq(x, cst_zero); + const Packet x_is_neg = pcmp_lt(x, cst_zero); + const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf); + const Packet abs_x_is_one = pcmp_eq(abs_x, cst_one); + const Packet abs_x_is_gt_one = pcmp_lt(cst_one, abs_x); + const Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_one); + const Packet x_is_one = pandnot(abs_x_is_one, x_is_neg); + const Packet x_is_neg_one = pand(abs_x_is_one, x_is_neg); + const Packet x_is_nan = pandnot(ptrue(x), pcmp_eq(x, x)); // Predicates for sign and magnitude of y. - Packet y_is_zero = pcmp_eq(y, cst_zero); - Packet y_is_neg = pcmp_lt(y, cst_zero); - Packet y_is_pos = pandnot(ptrue(y), por(y_is_zero, y_is_neg)); - Packet y_is_nan = pandnot(ptrue(y), pcmp_eq(y, y)); - Packet abs_y_is_inf = pcmp_eq(pabs(y), cst_pos_inf); - // |y| is so large that (1+eps)^y over- or underflows. + const Packet y_is_one = pcmp_eq(y, cst_one); + const Packet y_is_zero = pcmp_eq(y, cst_zero); + const Packet y_is_neg = pcmp_lt(y, cst_zero); + const Packet y_is_pos = pandnot(ptrue(y), por(y_is_zero, y_is_neg)); + const Packet y_is_nan = pandnot(ptrue(y), pcmp_eq(y, y)); + const Packet abs_y_is_inf = pcmp_eq(pabs(y), cst_pos_inf); EIGEN_CONSTEXPR Scalar huge_exponent = - (std::numeric_limits<Scalar>::digits * Scalar(EIGEN_LOG2E)) / + (std::numeric_limits<Scalar>::max_exponent * Scalar(EIGEN_LN2)) / std::numeric_limits<Scalar>::epsilon(); - Packet abs_y_is_huge = pcmp_lt(pset1<Packet>(huge_exponent), pabs(y)); + const Packet abs_y_is_huge = pcmp_le(pset1<Packet>(huge_exponent), pabs(y)); // Predicates for whether y is integer and/or even. - Packet y_is_int = pcmp_eq(pfloor(y), y); - Packet y_div_2 = pmul(y, cst_half); - Packet y_is_even = pcmp_eq(pround(y_div_2), y_div_2); + const Packet y_is_int = pcmp_eq(pfloor(y), y); + const Packet y_div_2 = pmul(y, pset1<Packet>(Scalar(0.5))); + const Packet y_is_even = pcmp_eq(pround(y_div_2), y_div_2); // Predicates encoding special cases for the value of pow(x,y) - Packet invalid_negative_x = pandnot(pandnot(pandnot(x_is_neg, abs_x_is_inf), y_is_int), abs_y_is_inf); - Packet pow_is_one = por(por(x_is_one, y_is_zero), - pand(x_is_neg_one, - por(abs_y_is_inf, pandnot(y_is_even, invalid_negative_x)))); - Packet pow_is_nan = por(invalid_negative_x, por(x_is_nan, y_is_nan)); - Packet pow_is_zero = por(por(por(pand(x_is_zero, y_is_pos), pand(abs_x_is_inf, y_is_neg)), - pand(pand(abs_x_is_lt_one, abs_y_is_huge), y_is_pos)), - pand(pand(abs_x_is_gt_one, abs_y_is_huge), y_is_neg)); - Packet pow_is_inf = por(por(por(pand(x_is_zero, y_is_neg), pand(abs_x_is_inf, y_is_pos)), - pand(pand(abs_x_is_lt_one, abs_y_is_huge), y_is_neg)), - pand(pand(abs_x_is_gt_one, abs_y_is_huge), y_is_pos)); + const Packet invalid_negative_x = pandnot(pandnot(pandnot(x_is_neg, abs_x_is_inf), + y_is_int), + abs_y_is_inf); + const Packet pow_is_one = por(por(x_is_one, y_is_zero), + pand(x_is_neg_one, + por(abs_y_is_inf, pandnot(y_is_even, invalid_negative_x)))); + const Packet pow_is_nan = por(invalid_negative_x, por(x_is_nan, y_is_nan)); + const Packet pow_is_zero = por(por(por(pand(x_is_zero, y_is_pos), + pand(abs_x_is_inf, y_is_neg)), + pand(pand(abs_x_is_lt_one, abs_y_is_huge), + y_is_pos)), + pand(pand(abs_x_is_gt_one, abs_y_is_huge), + y_is_neg)); + const Packet pow_is_inf = por(por(por(pand(x_is_zero, y_is_neg), + pand(abs_x_is_inf, y_is_pos)), + pand(pand(abs_x_is_lt_one, abs_y_is_huge), + y_is_neg)), + pand(pand(abs_x_is_gt_one, abs_y_is_huge), + y_is_pos)); // General computation of pow(x,y) for positive x or negative x and integer y. - Packet negate_pow_abs = pandnot(x_is_neg, y_is_even); - Packet pow_abs = generic_pow_impl(abs_x, y); - - return pselect(pow_is_one, cst_one, - pselect(pow_is_nan, cst_nan, - pselect(pow_is_inf, cst_pos_inf, - pselect(pow_is_zero, cst_zero, - pselect(negate_pow_abs, pnegate(pow_abs), pow_abs))))); + const Packet negate_pow_abs = pandnot(x_is_neg, y_is_even); + const Packet pow_abs = generic_pow_impl(abs_x, y); + return pselect(y_is_one, x, + pselect(pow_is_one, cst_one, + pselect(pow_is_nan, cst_nan, + pselect(pow_is_inf, cst_pos_inf, + pselect(pow_is_zero, cst_zero, + pselect(negate_pow_abs, pnegate(pow_abs), pow_abs)))))); } + /* polevl (modified for Eigen) * * Evaluate polynomial diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index a1529bc96..7f7e44f89 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp @@ -14,6 +14,7 @@ template<typename Scalar> void pow_test() { const Scalar zero = Scalar(0); + const Scalar eps = std::numeric_limits<Scalar>::epsilon(); const Scalar one = Scalar(1); const Scalar two = Scalar(2); const Scalar three = Scalar(3); @@ -21,20 +22,25 @@ void pow_test() { const Scalar sqrt2 = Scalar(std::sqrt(2)); const Scalar inf = std::numeric_limits<Scalar>::infinity(); const Scalar nan = std::numeric_limits<Scalar>::quiet_NaN(); + const Scalar denorm_min = std::numeric_limits<Scalar>::denorm_min(); const Scalar min = (std::numeric_limits<Scalar>::min)(); const Scalar max = (std::numeric_limits<Scalar>::max)(); + const Scalar max_exp = (static_cast<Scalar>(std::numeric_limits<Scalar>::max_exponent) * Scalar(EIGEN_LN2)) / eps; + const static Scalar abs_vals[] = {zero, + denorm_min, + min, + eps, sqrt_half, one, sqrt2, two, three, - min, + max_exp, max, inf, nan}; - - const int abs_cases = 10; + const int abs_cases = 13; const int num_cases = 2*abs_cases * 2*abs_cases; // Repeat the same value to make sure we hit the vectorized path. const int num_repeats = 32; @@ -64,10 +70,7 @@ void pow_test() { bool all_pass = true; for (int i = 0; i < 1; ++i) { for (int j = 0; j < num_cases; ++j) { - // TODO(rmlarsen): Skip tests that trigger a known bug in pldexp for now. - if (std::abs(x(i,j)) == max || std::abs(x(i,j)) == min) continue; - - Scalar e = numext::pow(x(i,j), y(i,j)); + Scalar e = static_cast<Scalar>(std::pow(x(i,j), y(i,j))); Scalar a = actual(i, j); bool fail = !(a==e) && !internal::isApprox(a, e, tol) && !((numext::isnan)(a) && (numext::isnan)(e)); all_pass &= !fail; @@ -79,7 +82,6 @@ void pow_test() { VERIFY(all_pass); } - template<typename ArrayType> void array(const ArrayType& m) { typedef typename ArrayType::Scalar Scalar; |