From 88d4c6d4c870f53d129ab5f8b43e01812d9b500e Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Mon, 22 Feb 2021 16:06:00 -0800 Subject: Accurate pow, part 2. This change adds specializations of log2 and exp2 for double that make pow accurate the 1 ULP. Speed for AVX-512 is within 0.5% of the currect implementation. --- .../Core/arch/Default/GenericPacketMathFunctions.h | 225 ++++++++++++++++++++- 1 file changed, 222 insertions(+), 3 deletions(-) (limited to 'Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h') diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index e4c7e6223..d8603b406 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -484,7 +484,7 @@ Packet pexp_float(const Packet _x) y1 = pmadd(y1, r, cst_cephes_exp_p5); y = pmadd(y, r3, y1); y = pmadd(y, r2, y2); - + // Return 2^m * exp(r). // TODO: replace pldexp with faster implementation since y in [-1, 1). return pmax(pldexp(y,m), _x); @@ -1003,6 +1003,20 @@ EIGEN_STRONG_INLINE fast_twosum(r_hi, s, s_hi, s_lo); } +// This is a version of twosum for adding a floating point number x to +// double word number {y_hi, y_lo} number, with the assumption +// that |x| >= |y_hi|. +template +EIGEN_STRONG_INLINE +void fast_twosum(const Packet& x, + const Packet& y_hi, const Packet& y_lo, + Packet& s_hi, Packet& s_lo) { + Packet r_hi, r_lo; + fast_twosum(x, y_hi, r_hi, r_lo); + const Packet s = padd(y_lo, r_lo); + fast_twosum(r_hi, s, s_hi, s_lo); +} + // This function implements the multiplication of a double word // number represented by {x_hi, x_lo} by a floating point number y. // It returns the result as a pair {p_hi, p_lo} such that @@ -1024,6 +1038,50 @@ void twoprod(const Packet& x_hi, const Packet& x_lo, const Packet& y, fast_twosum(t_hi, t_lo2, p_hi, p_lo); } +// This function implements the multiplication of two double word +// numbers represented by {x_hi, x_lo} and {y_hi, y_lo}. +// It returns the result as a pair {p_hi, p_lo} such that +// (x_hi + x_lo) * (y_hi + y_lo) = p_hi + p_lo holds with a relative error +// of less than 2*2^{-2p}, where p is the number of significand bit +// in the floating point type. +template +EIGEN_STRONG_INLINE +void twoprod(const Packet& x_hi, const Packet& x_lo, + const Packet& y_hi, const Packet& y_lo, + Packet& p_hi, Packet& p_lo) { + Packet p_hi_hi, p_hi_lo; + twoprod(x_hi, x_lo, y_hi, p_hi_hi, p_hi_lo); + Packet p_lo_hi, p_lo_lo; + twoprod(x_hi, x_lo, y_lo, p_lo_hi, p_lo_lo); + fast_twosum(p_hi_hi, p_hi_lo, p_lo_hi, p_lo_lo, p_hi, p_lo); +} + +// This function computes the reciprocal of a floating point number +// with extra precision and returns the result as a double word. +template +void doubleword_reciprocal(const Packet& x, Packet& recip_hi, Packet& recip_lo) { + typedef typename unpacket_traits::type Scalar; + // 1. Approximate the reciprocal as the reciprocal of the high order element. + Packet approx_recip = prsqrt(x); + approx_recip = pmul(approx_recip, approx_recip); + + // 2. Run one step of Newton-Raphson iteration in double word arithmetic + // to get the bottom half. The NR iteration for reciprocal of 'a' is + // x_{i+1} = x_i * (2 - a * x_i) + + // -a*x_i + Packet t1_hi, t1_lo; + twoprod(pnegate(x), approx_recip, t1_hi, t1_lo); + // 2 - a*x_i + Packet t2_hi, t2_lo; + fast_twosum(pset1(Scalar(2)), t1_hi, t2_hi, t2_lo); + Packet t3_hi, t3_lo; + fast_twosum(t2_hi, padd(t2_lo, t1_lo), t3_hi, t3_lo); + // x_i * (2 - a * x_i) + twoprod(t3_hi, t3_lo, approx_recip, recip_hi, recip_lo); +} + + // This function computes log2(x) and returns the result as a double word. template struct accurate_log2 { @@ -1115,6 +1173,101 @@ struct accurate_log2 { } }; +// This specialization uses a more accurate algorithm to compute log2(x) for +// floats in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~1.27e-18. +// This additional accuracy is needed to counter the error-magnification +// inherent in multiplying by a potentially large exponent in pow(x,y). +// The minimax polynomial used was calculated using the Sollya tool. +// See sollya.org. + +template <> +struct accurate_log2 { + template + EIGEN_STRONG_INLINE + void operator()(const Packet& x, Packet& log2_x_hi, Packet& log2_x_lo) { + // We use a transformation of variables: + // r = c * (x-1) / (x+1), + // such that + // log2(x) = log2((1 + r/c) / (1 - r/c)) = f(r). + // The function f(r) can be approximated well using an odd polynomial + // of the form + // P(r) = ((Q(r^2) * r^2 + C) * r^2 + 1) * r, + // For the implementation of log2 here, Q is of degree 6 with + // coefficient represented in working precision (double), while C is a + // constant represented in extra precision as a double word to achieve + // full accuracy. + // + // The polynomial coefficients were computed by the Sollya script: + // + // c = 2 / log(2); + // trans = c * (x-1)/(x+1); + // itrans = (1+x/c)/(1-x/c); + // interval=[trans(sqrt(0.5)); trans(sqrt(2))]; + // print(interval); + // f = log2(itrans(x)); + // p=fpminimax(f,[|1,3,5,7,9,11,13,15,17|],[|1,DD,double...|],interval,relative,floating); + const Packet q12 = pset1(2.87074255468000586e-9); + const Packet q10 = pset1(2.38957980901884082e-8); + const Packet q8 = pset1(2.31032094540014656e-7); + const Packet q6 = pset1(2.27279857398537278e-6); + const Packet q4 = pset1(2.31271023278625638e-5); + const Packet q2 = pset1(2.47556738444535513e-4); + const Packet q0 = pset1(2.88543873228900172e-3); + const Packet C_hi = pset1(0.0400377511598501157); + const Packet C_lo = pset1(-4.77726582251425391e-19); + const Packet one = pset1(1.0); + + const Packet cst_2_log2e_hi = pset1(2.88539008177792677); + const Packet cst_2_log2e_lo = pset1(4.07660016854549667e-17); + // c * (x - 1) + Packet num_hi, num_lo; + twoprod(cst_2_log2e_hi, cst_2_log2e_lo, psub(x, one), num_hi, num_lo); + // TODO(rmlarsen): Investigate if using the division algorithm by + // Muller et al. is faster/more accurate. + // 1 / (x + 1) + Packet denom_hi, denom_lo; + doubleword_reciprocal(padd(x, one), denom_hi, denom_lo); + // r = c * (x-1) / (x+1), + Packet r_hi, r_lo; + twoprod(num_hi, num_lo, denom_hi, denom_lo, r_hi, r_lo); + // r2 = r * r + Packet r2_hi, r2_lo; + twoprod(r_hi, r_lo, r_hi, r_lo, r2_hi, r2_lo); + // r4 = r2 * r2 + Packet r4_hi, r4_lo; + twoprod(r2_hi, r2_lo, r2_hi, r2_lo, r4_hi, r4_lo); + + // Evaluate Q(r^2) in working precision. We evaluate it in two parts + // (even and odd in r^2) to improve instruction level parallelism. + Packet q_even = pmadd(q12, r4_hi, q8); + Packet q_odd = pmadd(q10, r4_hi, q6); + q_even = pmadd(q_even, r4_hi, q4); + q_odd = pmadd(q_odd, r4_hi, q2); + q_even = pmadd(q_even, r4_hi, q0); + Packet q = pmadd(q_odd, r2_hi, q_even); + + // Now evaluate the low order terms of P(x) in double word precision. + // In the following, due to the increasing magnitude of the coefficients + // and r being constrained to [-0.5, 0.5] we can use fast_twosum instead + // of the slower twosum. + // Q(r^2) * r^2 + Packet p_hi, p_lo; + twoprod(r2_hi, r2_lo, q, p_hi, p_lo); + // Q(r^2) * r^2 + C + Packet p1_hi, p1_lo; + fast_twosum(C_hi, C_lo, p_hi, p_lo, p1_hi, p1_lo); + // (Q(r^2) * r^2 + C) * r^2 + Packet p2_hi, p2_lo; + twoprod(r2_hi, r2_lo, p1_hi, p1_lo, p2_hi, p2_lo); + // ((Q(r^2) * r^2 + C) * r^2 + 1) + Packet p3_hi, p3_lo; + fast_twosum(one, p2_hi, p2_lo, p3_hi, p3_lo); + + // log(z) ~= ((Q(r^2) * r^2 + C) * r^2 + 1) * r + twoprod(p3_hi, p3_lo, r_hi, r_lo, log2_x_hi, log2_x_lo); + } +}; + // This function computes exp2(x) (i.e. 2**x). template struct fast_accurate_exp2 { @@ -1161,8 +1314,75 @@ struct fast_accurate_exp2 { // to gain some instruction level parallelism. Packet x2 = pmul(x,x); Packet p_even = pmadd(p4, x2, p2); - p_even = pmadd(p_even, x2, p0); Packet p_odd = pmadd(p3, x2, p1); + p_even = pmadd(p_even, x2, p0); + Packet p = pmadd(p_odd, x, p_even); + + // Evaluate the remaining terms of Q(x) with extra precision using + // double word arithmetic. + Packet p_hi, p_lo; + // x * p(x) + twoprod(p, x, p_hi, p_lo); + // C + x * p(x) + Packet q1_hi, q1_lo; + twosum(p_hi, p_lo, C_hi, C_lo, q1_hi, q1_lo); + // x * (C + x * p(x)) + Packet q2_hi, q2_lo; + twoprod(q1_hi, q1_lo, x, q2_hi, q2_lo); + // 1 + x * (C + x * p(x)) + Packet q3_hi, q3_lo; + // Since |q2_hi| <= sqrt(2)-1 < 1, we can use fast_twosum + // for adding it to unity here. + fast_twosum(one, q2_hi, q3_hi, q3_lo); + return padd(q3_hi, padd(q2_lo, q3_lo)); + } +}; + +// in [-0.5;0.5] with a relative accuracy of 1 ulp. +// The minimax polynomial used was calculated using the Sollya tool. +// See sollya.org. +template <> +struct fast_accurate_exp2 { + template + EIGEN_STRONG_INLINE + Packet operator()(const Packet& x) { + // This function approximates exp2(x) by a degree 10 polynomial of the form + // Q(x) = 1 + x * (C + x * P(x)), where the degree 8 polynomial P(x) is evaluated in + // single precision, and the remaining steps are evaluated with extra precision using + // double word arithmetic. C is an extra precise constant stored as a double word. + // + // The polynomial coefficients were calculated using Sollya commands: + // > n = 11; + // > f = 2^x; + // > interval = [-0.5;0.5]; + // > p = fpminimax(f,n,[|1,DD,double...|],interval,relative,floating); + + const Packet p9 = pset1(4.431642109085495276e-10); + const Packet p8 = pset1(7.073829923303358410e-9); + const Packet p7 = pset1(1.017822306737031311e-7); + const Packet p6 = pset1(1.321543498017646657e-6); + const Packet p5 = pset1(1.525273342728892877e-5); + const Packet p4 = pset1(1.540353045780084423e-4); + const Packet p3 = pset1(1.333355814685869807e-3); + const Packet p2 = pset1(9.618129107593478832e-3); + const Packet p1 = pset1(5.550410866481961247e-2); + const Packet p0 = pset1(0.240226506959101332); + const Packet C_hi = pset1(0.693147180559945286); + const Packet C_lo = pset1(4.81927865669806721e-17); + const Packet one = pset1(1.0f); + + // Evaluate P(x) in working precision. + // We evaluate even and odd parts of the polynomial separately + // to gain some instruction level parallelism. + Packet x2 = pmul(x,x); + Packet p_even = pmadd(p8, x2, p6); + Packet p_odd = pmadd(p9, x2, p7); + p_even = pmadd(p_even, x2, p4); + p_odd = pmadd(p_odd, x2, p5); + p_even = pmadd(p_even, x2, p2); + p_odd = pmadd(p_odd, x2, p3); + p_even = pmadd(p_even, x2, p0); + p_odd = pmadd(p_odd, x2, p1); Packet p = pmadd(p_odd, x, p_even); // Evaluate the remaining terms of Q(x) with extra precision using @@ -1234,7 +1454,6 @@ EIGEN_STRONG_INLINE Packet generic_pow_impl(const Packet& x, const Packet& y) { // Since r_z is in [-0.5;0.5], we compute the first factor to high accuracy // using a specialized algorithm. Multiplication by the second factor can // be done exactly using pldexp(), since it is an integer power of 2. - // Packet e_r = fast_accurate_exp2()(r_z); const Packet e_r = fast_accurate_exp2()(r_z); return pldexp(e_r, n_z); } -- cgit v1.2.3