aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
diff options
context:
space:
mode:
authorGravatar Rasmus Munk Larsen <rmlarsen@google.com>2021-02-17 02:50:32 +0000
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2021-02-17 02:50:32 +0000
commitbe0574e2159ce3d6a1748ba6060bea5dedccdbc9 (patch)
treeb4aa77c9f55ebb896dde3b3ee5d8acf6c139a6a8 /Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
parent7ff0b7a980ceffe7d0e72ebac924f514f7874e9b (diff)
New accurate algorithm for pow(x,y). This version is accurate to 1.4 ulps for float, while still being 10x faster than std::pow for AVX512. A future change will introduce a specialization for double.
Diffstat (limited to 'Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h')
-rw-r--r--Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h512
1 files changed, 384 insertions, 128 deletions
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
index 42d310ab2..abdbcdbb9 100644
--- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
@@ -1,3 +1,4 @@
+
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
@@ -36,7 +37,7 @@ template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
Packet pfrexp_generic_get_biased_exponent(const Packet& a) {
typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
- EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1;
+ enum { mantissa_bits = numext::numeric_limits<Scalar>::digits - 1};
return pcast<PacketI, Packet>(plogical_shift_right<mantissa_bits>(preinterpret<PacketI>(pabs(a))));
}
@@ -46,28 +47,29 @@ template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
Packet pfrexp_generic(const Packet& a, Packet& exponent) {
typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename make_unsigned<typename make_integer<Scalar>::type>::type ScalarUI;
-
- EIGEN_CONSTEXPR int total_bits = sizeof(Scalar) * CHAR_BIT;
- EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1;
- EIGEN_CONSTEXPR int exponent_bits = total_bits - mantissa_bits - 1;
-
+ enum {
+ TotalBits = sizeof(Scalar) * CHAR_BIT,
+ MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
+ ExponentBits = int(TotalBits) - int(MantissaBits) - 1
+ };
+
EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask =
- ~(((ScalarUI(1) << exponent_bits) - ScalarUI(1)) << mantissa_bits); // ~0x7f800000
+ ~(((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)) << int(MantissaBits)); // ~0x7f800000
const Packet sign_mantissa_mask = pset1frombits<Packet>(static_cast<ScalarUI>(scalar_sign_mantissa_mask));
const Packet half = pset1<Packet>(Scalar(0.5));
const Packet zero = pzero(a);
const Packet normal_min = pset1<Packet>((numext::numeric_limits<Scalar>::min)()); // Minimum normal value, 2^-126
- // To handle denormals, normalize by multiplying by 2^(mantissa_bits+1).
+ // To handle denormals, normalize by multiplying by 2^(MantissaBits+1).
const Packet is_denormal = pcmp_lt(pabs(a), normal_min);
- EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(mantissa_bits + 1); // 24
+ EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(int(MantissaBits) + 1); // 24
// The following cannot be constexpr because bfloat16(uint16_t) is not constexpr.
const Scalar scalar_normalization_factor = Scalar(ScalarUI(1) << int(scalar_normalization_offset)); // 2^24
const Packet normalization_factor = pset1<Packet>(scalar_normalization_factor);
const Packet normalized_a = pselect(is_denormal, pmul(a, normalization_factor), a);
// Determine exponent offset: -126 if normal, -126-24 if denormal
- const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(exponent_bits-1)) - ScalarUI(2)); // -126
+ const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(int(ExponentBits)-1)) - ScalarUI(2)); // -126
Packet exponent_offset = pset1<Packet>(scalar_exponent_offset);
const Packet normalization_offset = pset1<Packet>(-Scalar(scalar_normalization_offset)); // -24
exponent_offset = pselect(is_denormal, padd(exponent_offset, normalization_offset), exponent_offset);
@@ -76,7 +78,7 @@ Packet pfrexp_generic(const Packet& a, Packet& exponent) {
exponent = pfrexp_generic_get_biased_exponent(normalized_a);
// Zero, Inf and NaN return 'a' unmodified, exponent is zero
// (technically the exponent is unspecified for inf/NaN, but GCC/Clang set it to zero)
- const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << exponent_bits) - ScalarUI(1)); // 255
+ const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)); // 255
const Packet non_finite_exponent = pset1<Packet>(scalar_non_finite_exponent);
const Packet is_zero_or_not_finite = por(pcmp_eq(a, zero), pcmp_eq(exponent, non_finite_exponent));
const Packet m = pselect(is_zero_or_not_finite, a, por(pand(normalized_a, sign_mantissa_mask), half));
@@ -113,18 +115,20 @@ Packet pldexp_generic(const Packet& a, const Packet& exponent) {
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename unpacket_traits<PacketI>::type ScalarI;
- EIGEN_CONSTEXPR int total_bits = sizeof(Scalar) * CHAR_BIT;
- EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1;
- EIGEN_CONSTEXPR int exponent_bits = total_bits - mantissa_bits - 1;
+ enum {
+ TotalBits = sizeof(Scalar) * CHAR_BIT,
+ MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
+ ExponentBits = int(TotalBits) - int(MantissaBits) - 1
+ };
- const Packet max_exponent = pset1<Packet>(Scalar((ScalarI(1)<<exponent_bits) + ScalarI(mantissa_bits - 1))); // 278
- const PacketI bias = pset1<PacketI>((ScalarI(1)<<(exponent_bits-1)) - ScalarI(1)); // 127
+ const Packet max_exponent = pset1<Packet>(Scalar((ScalarI(1)<<int(ExponentBits)) + ScalarI(int(MantissaBits) - 1))); // 278
+ const PacketI bias = pset1<PacketI>((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1)); // 127
const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
PacketI b = parithmetic_shift_right<2>(e); // floor(e/4);
- Packet c = preinterpret<Packet>(plogical_shift_left<mantissa_bits>(padd(b, bias))); // 2^b
+ Packet c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^b
Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
b = psub(psub(psub(e, b), b), b); // e - 3b
- c = preinterpret<Packet>(plogical_shift_left<mantissa_bits>(padd(b, bias))); // 2^(e-3*b)
+ c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^(e-3*b)
out = pmul(out, c);
return out;
}
@@ -890,112 +894,355 @@ Packet psqrt_complex(const Packet& a) {
pselect(is_real_inf, real_inf_result,result));
}
+// TODO(rmlarsen): The following set of utilities for double word arithmetic
+// should perhaps be refactored as a separate file, since it would be generally
+// useful for special function implementation etc. Writing the algorithms in
+// terms of a double word type would also make the code more readable.
+
+// This function splits x into the nearest integer n and fractional part r,
+// such that x = n + r holds exactly.
+template<typename Packet>
+EIGEN_STRONG_INLINE
+void absolute_split(const Packet& x, Packet& n, Packet& r) {
+ n = pround(x);
+ r = psub(x, n);
+}
+
+// This function computes the sum {s_hi, s_lo}, such that x + y = s_hi + s_lo
+// holds exactly, and s_hi = fl(x+y), if |x| >= |y|.
+template<typename Packet>
+EIGEN_STRONG_INLINE
+void fast_twosum(const Packet& x, const Packet& y, Packet& s_hi, Packet& s_lo) {
+ s_hi = padd(x, y);
+ const Packet t = psub(s_hi, x);
+ s_lo = psub(y, t);
+}
+
+#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
+// This function implements the extended precision product of
+// a pair of floating point numbers. Given {x, y}, it computes the pair
+// {p_hi, p_lo} such that x * y = p_hi + p_lo holds exactly and
+// p_hi = fl(x * y).
+template<typename Packet>
+EIGEN_STRONG_INLINE
+void twoprod(const Packet& x, const Packet& y,
+ Packet& p_hi, Packet& p_lo) {
+ p_hi = pmul(x, y);
+ p_lo = pmadd(x, y, pnegate(p_hi));
+}
+
+#else
// This function implements the Veltkamp splitting. Given a floating point
// number x it returns the pair {x_hi, x_lo} such that x_hi + x_lo = x holds
// exactly and that half of the significant of x fits in x_hi.
-// This code corresponds to Algorithms 3 and 4 in
-// https://hal.inria.fr/hal-01774587v2/document
+// This is Algorithm 3 from Jean-Michel Muller, "Elementary Functions",
+// 3rd edition, Birkh\"auser, 2016.
template<typename Packet>
EIGEN_STRONG_INLINE
void veltkamp_splitting(const Packet& x, Packet& x_hi, Packet& x_lo) {
typedef typename unpacket_traits<Packet>::type Scalar;
EIGEN_CONSTEXPR int shift = (NumTraits<Scalar>::digits() + 1) / 2;
- Scalar shift_scale = Scalar(uint64_t(1) << shift); // Scalar constructor not necessarily constexpr.
- Packet gamma = pmul(pset1<Packet>(shift_scale + Scalar(1)), x);
-#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
- x_hi = pmadd(pset1<Packet>(-shift_scale), x, gamma);
-#else
+ const Scalar shift_scale = Scalar(uint64_t(1) << shift); // Scalar constructor not necessarily constexpr.
+ const Packet gamma = pmul(pset1<Packet>(shift_scale + Scalar(1)), x);
Packet rho = psub(x, gamma);
x_hi = padd(rho, gamma);
-#endif
x_lo = psub(x, x_hi);
}
-// This function splits x into the nearest integer n and fractional part r,
-// such that x = n + r holds exactly.
+// This function implements Dekker's algorithm for products x * y.
+// Given floating point numbers {x, y} computes the pair
+// {p_hi, p_lo} such that x * y = p_hi + p_lo holds exactly and
+// p_hi = fl(x * y).
template<typename Packet>
EIGEN_STRONG_INLINE
-void integer_split(const Packet& x, Packet& n, Packet& r) {
- n = pround(x);
- r = psub(x, n);
+void twoprod(const Packet& x, const Packet& y,
+ Packet& p_hi, Packet& p_lo) {
+ Packet x_hi, x_lo, y_hi, y_lo;
+ veltkamp_splitting(x, x_hi, x_lo);
+ veltkamp_splitting(y, y_hi, y_lo);
+
+ p_hi = pmul(x, y);
+ p_lo = pmadd(x_hi, y_hi, pnegate(p_hi));
+ p_lo = pmadd(x_hi, y_lo, p_lo);
+ p_lo = pmadd(x_lo, y_hi, p_lo);
+ p_lo = pmadd(x_lo, y_lo, p_lo);
}
-// This function implements Dekker's algorithm for two products {x * y1, x * y2} with
-// a shared factor. Given floating point numbers {x, y1, y2} computes the pairs
-// {p1, r1} and {p2, r2} such that x * y1 = p1 + r1 holds exactly and
-// p1 = fl(x * y1), and x * y2 = p2 + r2 holds exactly and p2 = fl(x * y2).
+#endif // EIGEN_HAS_SINGLE_INSTRUCTION_MADD
+
+
+// This function implements Dekker's algorithm for the addition
+// of two double word numbers represented by {x_hi, x_lo} and {y_hi, y_lo}.
+// It returns the result as a pair {s_hi, s_lo} such that
+// x_hi + x_lo + y_hi + y_lo = s_hi + s_lo holds exactly.
+// This is Algorithm 5 from Jean-Michel Muller, "Elementary Functions",
+// 3rd edition, Birkh\"auser, 2016.
template<typename Packet>
EIGEN_STRONG_INLINE
-void double_dekker(const Packet& x, const Packet& y1, const Packet& y2,
- Packet& p1, Packet& r1, Packet& p2, Packet& r2) {
- Packet x_hi, x_lo, y1_hi, y1_lo, y2_hi, y2_lo;
- veltkamp_splitting(x, x_hi, x_lo);
- veltkamp_splitting(y1, y1_hi, y1_lo);
- veltkamp_splitting(y2, y2_hi, y2_lo);
-
- p1 = pmul(x, y1);
- r1 = pmadd(x_hi, y1_hi, pnegate(p1));
- r1 = pmadd(x_hi, y1_lo, r1);
- r1 = pmadd(x_lo, y1_hi, r1);
- r1 = pmadd(x_lo, y1_lo, r1);
-
- p2 = pmul(x, y2);
- r2 = pmadd(x_hi, y2_hi, pnegate(p2));
- r2 = pmadd(x_hi, y2_lo, r2);
- r2 = pmadd(x_lo, y2_hi, r2);
- r2 = pmadd(x_lo, y2_lo, r2);
+ void twosum(const Packet& x_hi, const Packet& x_lo,
+ const Packet& y_hi, const Packet& y_lo,
+ Packet& s_hi, Packet& s_lo) {
+ const Packet x_greater_mask = pcmp_lt(pabs(y_hi), pabs(x_hi));
+ Packet r_hi_1, r_lo_1;
+ fast_twosum(x_hi, y_hi,r_hi_1, r_lo_1);
+ Packet r_hi_2, r_lo_2;
+ fast_twosum(y_hi, x_hi,r_hi_2, r_lo_2);
+ const Packet r_hi = pselect(x_greater_mask, r_hi_1, r_hi_2);
+
+ const Packet s1 = padd(padd(y_lo, r_lo_1), x_lo);
+ const Packet s2 = padd(padd(x_lo, r_lo_2), y_lo);
+ const Packet s = pselect(x_greater_mask, s1, s2);
+
+ fast_twosum(r_hi, s, s_hi, s_lo);
}
-// This function implements the non-trivial case of pow(x,y) where x is
-// positive and y is (possibly) non-integer.
-// Formally, pow(x,y) = 2**(y * log2(x))
+// This is a version of twosum for double word numbers,
+// which assumes that |x_hi| >= |y_hi|.
+template<typename Packet>
+EIGEN_STRONG_INLINE
+ void fast_twosum(const Packet& x_hi, const Packet& x_lo,
+ const Packet& y_hi, const Packet& y_lo,
+ Packet& s_hi, Packet& s_lo) {
+ Packet r_hi, r_lo;
+ fast_twosum(x_hi, y_hi, r_hi, r_lo);
+ const Packet s = padd(padd(y_lo, r_lo), x_lo);
+ fast_twosum(r_hi, s, s_hi, s_lo);
+}
+
+// This function implements the multiplication of a double word
+// number represented by {x_hi, x_lo} by a floating point number y.
+// It returns the result as a pair {p_hi, p_lo} such that
+// (x_hi + x_lo) * y = p_hi + p_lo holds with a relative error
+// of less than 2*2^{-2p}, where p is the number of significand bits
+// in the floating point type.
+// This is Algorithm 7 from Jean-Michel Muller, "Elementary Functions",
+// 3rd edition, Birkh\"auser, 2016.
template<typename Packet>
EIGEN_STRONG_INLINE
-Packet generic_pow_impl(const Packet& x, const Packet& y) {
+void twoprod(const Packet& x_hi, const Packet& x_lo, const Packet& y,
+ Packet& p_hi, Packet& p_lo) {
+ Packet c_hi, c_lo1;
+ twoprod(x_hi, y, c_hi, c_lo1);
+ const Packet c_lo2 = pmul(x_lo, y);
+ Packet t_hi, t_lo1;
+ fast_twosum(c_hi, c_lo2, t_hi, t_lo1);
+ const Packet t_lo2 = padd(t_lo1, c_lo1);
+ fast_twosum(t_hi, t_lo2, p_hi, p_lo);
+}
+
+// This function computes log2(x) and returns the result as a double word.
+template <typename Scalar>
+struct accurate_log2 {
+ template <typename Packet>
+ EIGEN_STRONG_INLINE
+ void operator()(const Packet& x, Packet& log2_x_hi, Packet& log2_x_lo) {
+ log2_x_hi = plog2(x);
+ log2_x_lo = pzero(x);
+ }
+};
+
+// This specialization uses a more accurate algorithm to compute log2(x) for
+// floats in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~6.42e-10.
+// This additional accuracy is needed to counter the error-magnification
+// inherent in multiplying by a potentially large exponent in pow(x,y).
+// The minimax polynomial used was calculated using the Sollya tool.
+// See sollya.org.
+template <>
+struct accurate_log2<float> {
+ template <typename Packet>
+ EIGEN_STRONG_INLINE
+ void operator()(const Packet& z, Packet& log2_x_hi, Packet& log2_x_lo) {
+ // The function log2(1+x)/x is approximated in the interval
+ // [1/sqrt(2)-1;sqrt(2)-1] by a degree 10 polynomial of the form
+ // Q(x) = (C0 + x * (C1 + x * (C2 + x * (C3 + x * P(x))))),
+ // where the degree 6 polynomial P(x) is evaluated in single precision,
+ // while the remaining 4 terms of Q(x), as well as the final multiplication by x
+ // to reconstruct log2(1+x), are evaluated in extra precision using
+ // double word arithmetic. C0 through C3 are extra precise constants
+ // stored as double words.
+ //
+ // The polynomial coefficients were calculated using Sollya commands:
+ // > n = 10;
+ // > f = log2(1+x)/x;
+ // > interval = [sqrt(0.5)-1;sqrt(2)-1];
+ // > p = fpminimax(f,n,[|double,double,double,double,single...|],interval,relative,floating);
+
+ const Packet p6 = pset1<Packet>( 9.703654795885e-2f);
+ const Packet p5 = pset1<Packet>(-0.1690667718648f);
+ const Packet p4 = pset1<Packet>( 0.1720575392246f);
+ const Packet p3 = pset1<Packet>(-0.1789081543684f);
+ const Packet p2 = pset1<Packet>( 0.2050433009862f);
+ const Packet p1 = pset1<Packet>(-0.2404672354459f);
+ const Packet p0 = pset1<Packet>( 0.2885761857032f);
+
+ const Packet C3_hi = pset1<Packet>(-0.360674142838f);
+ const Packet C3_lo = pset1<Packet>(-6.13283912543e-09f);
+ const Packet C2_hi = pset1<Packet>(0.480897903442f);
+ const Packet C2_lo = pset1<Packet>(-1.44861207474e-08f);
+ const Packet C1_hi = pset1<Packet>(-0.721347510815f);
+ const Packet C1_lo = pset1<Packet>(-4.84483164698e-09f);
+ const Packet C0_hi = pset1<Packet>(1.44269502163f);
+ const Packet C0_lo = pset1<Packet>(2.01711713999e-08f);
+ const Packet one = pset1<Packet>(1.0f);
+
+ const Packet x = psub(z, one);
+ // Evaluate P(x) in working precision.
+ // We evaluate it in multiple parts to improve instruction level
+ // parallelism.
+ Packet x2 = pmul(x,x);
+ Packet p_even = pmadd(p6, x2, p4);
+ p_even = pmadd(p_even, x2, p2);
+ p_even = pmadd(p_even, x2, p0);
+ Packet p_odd = pmadd(p5, x2, p3);
+ p_odd = pmadd(p_odd, x2, p1);
+ Packet p = pmadd(p_odd, x, p_even);
+
+ // Now evaluate the low-order terms of Q(x) in double word precision.
+ // In the following, due to the alternating signs and the fact that
+ // |x| < sqrt(2)-1, we can assume that |C*_hi| >= q_i, and use
+ // fast_twosum instead of the slower twosum.
+ Packet q_hi, q_lo;
+ Packet t_hi, t_lo;
+ // C3 + x * p(x)
+ twoprod(p, x, t_hi, t_lo);
+ fast_twosum(C3_hi, C3_lo, t_hi, t_lo, q_hi, q_lo);
+ // C2 + x * (C3 + x * p(x))
+ twoprod(q_hi, q_lo, x, t_hi, t_lo);
+ fast_twosum(C2_hi, C2_lo, t_hi, t_lo, q_hi, q_lo);
+ // C1 + x * (C2 + x * (C3 + x * p(x)))
+ twoprod(q_hi, q_lo, x, t_hi, t_lo);
+ fast_twosum(C1_hi, C1_lo, t_hi, t_lo, q_hi, q_lo);
+ // C0 + x * (C1 + x * (C2 + x * (C3 + x * p(x))))
+ twoprod(q_hi, q_lo, x, t_hi, t_lo);
+ fast_twosum(C0_hi, C0_lo, t_hi, t_lo, q_hi, q_lo);
+
+ // log2(z) ~= x * Q(x)
+ twoprod(q_hi, q_lo, x, log2_x_hi, log2_x_lo);
+ }
+};
+
+// This function computes exp2(x) (i.e. 2**x).
+template <typename Scalar>
+struct fast_accurate_exp2 {
+ template <typename Packet>
+ EIGEN_STRONG_INLINE
+ Packet operator()(const Packet& x) {
+ // TODO(rmlarsen): Add a pexp2 packetop.
+ return pexp(pmul(pset1<Packet>(Scalar(EIGEN_LN2)), x));
+ }
+};
+
+// This specialization uses a faster algorithm to compute exp2(x) for floats
+// in [-0.5;0.5] with a relative accuracy of 1 ulp.
+// The minimax polynomial used was calculated using the Sollya tool.
+// See sollya.org.
+template <>
+struct fast_accurate_exp2<float> {
+ template <typename Packet>
+ EIGEN_STRONG_INLINE
+ Packet operator()(const Packet& x) {
+ // This function approximates exp2(x) by a degree 6 polynomial of the form
+ // Q(x) = 1 + x * (C + x * P(x)), where the degree 4 polynomial P(x) is evaluated in
+ // single precision, and the remaining steps are evaluated with extra precision using
+ // double word arithmetic. C is an extra precise constant stored as a double word.
+ //
+ // The polynomial coefficients were calculated using Sollya commands:
+ // > n = 6;
+ // > f = 2^x;
+ // > interval = [-0.5;0.5];
+ // > p = fpminimax(f,n,[|1,double,single...|],interval,relative,floating);
+
+ const Packet p4 = pset1<Packet>(1.539513905e-4f);
+ const Packet p3 = pset1<Packet>(1.340007293e-3f);
+ const Packet p2 = pset1<Packet>(9.618283249e-3f);
+ const Packet p1 = pset1<Packet>(5.550328270e-2f);
+ const Packet p0 = pset1<Packet>(0.2402264923f);
+
+ const Packet C_hi = pset1<Packet>(0.6931471825f);
+ const Packet C_lo = pset1<Packet>(2.36836577e-08f);
+ const Packet one = pset1<Packet>(1.0f);
+
+ // Evaluate P(x) in working precision.
+ // We evaluate even and odd parts of the polynomial separately
+ // to gain some instruction level parallelism.
+ Packet x2 = pmul(x,x);
+ Packet p_even = pmadd(p4, x2, p2);
+ p_even = pmadd(p_even, x2, p0);
+ Packet p_odd = pmadd(p3, x2, p1);
+ Packet p = pmadd(p_odd, x, p_even);
+
+ // Evaluate the remaining terms of Q(x) with extra precision using
+ // double word arithmetic.
+ Packet p_hi, p_lo;
+ // x * p(x)
+ twoprod(p, x, p_hi, p_lo);
+ // C + x * p(x)
+ Packet q1_hi, q1_lo;
+ twosum(p_hi, p_lo, C_hi, C_lo, q1_hi, q1_lo);
+ // x * (C + x * p(x))
+ Packet q2_hi, q2_lo;
+ twoprod(q1_hi, q1_lo, x, q2_hi, q2_lo);
+ // 1 + x * (C + x * p(x))
+ Packet q3_hi, q3_lo;
+ // Since |q2_hi| <= sqrt(2)-1 < 1, we can use fast_twosum
+ // for adding it to unity here.
+ fast_twosum(one, q2_hi, q3_hi, q3_lo);
+ return padd(q3_hi, padd(q2_lo, q3_lo));
+ }
+};
+
+// This function implements the non-trivial case of pow(x,y) where x is
+// positive and y is (possibly) non-integer.
+// Formally, pow(x,y) = exp2(y * log2(x)), where exp2(x) is shorthand for 2^x.
+// TODO(rmlarsen): We should probably add this as a packet op 'ppow', to make it
+// easier to specialize or turn off for specific types and/or backends.
+template <typename Packet>
+EIGEN_STRONG_INLINE Packet generic_pow_impl(const Packet& x, const Packet& y) {
typedef typename unpacket_traits<Packet>::type Scalar;
// Split x into exponent e_x and mantissa m_x.
Packet e_x;
Packet m_x = pfrexp(x, e_x);
- // Adjust m_x to lie in [0.75:1.5) to minimize absolute error in log2(m_x).
- Packet m_x_scale_mask = pcmp_lt(m_x, pset1<Packet>(Scalar(0.75)));
+ // Adjust m_x to lie in [1/sqrt(2):sqrt(2)] to minimize absolute error in log2(m_x).
+ EIGEN_CONSTEXPR Scalar sqrt_half = Scalar(0.70710678118654752440);
+ const Packet m_x_scale_mask = pcmp_lt(m_x, pset1<Packet>(sqrt_half));
m_x = pselect(m_x_scale_mask, pmul(pset1<Packet>(Scalar(2)), m_x), m_x);
e_x = pselect(m_x_scale_mask, psub(e_x, pset1<Packet>(Scalar(1))), e_x);
- Packet r_x = plog2(m_x);
+ // Compute log2(m_x) with 6 extra bits of accuracy.
+ Packet rx_hi, rx_lo;
+ accurate_log2<Scalar>()(m_x, rx_hi, rx_lo);
// Compute the two terms {y * e_x, y * r_x} in f = y * log2(x) with doubled
- // precision using Dekker's algorithm.
+ // precision using double word arithmetic.
Packet f1_hi, f1_lo, f2_hi, f2_lo;
- double_dekker(y, e_x, r_x, f1_hi, f1_lo, f2_hi, f2_lo);
-
- // Separate f into integer and fractional parts, keeping f1_hi, and f2_hi
- // separate to avoid cancellation.
- Packet n1, r1, n2, r2;
- integer_split(f1_hi, n1, r1);
- integer_split(f2_hi, n2, r2);
-
- // Add up integer parts and sum the remainders.
- Packet n_z = padd(n1, n2);
- // Notice: I experimented with using compensated (Kahan) summation here,
- // but it does not seem to matter.
- Packet rem = padd(padd(f1_lo, f2_lo), padd(r1, r2));
-
- // Extract any additional integer part that may have accumulated in rem.
- Packet nrem, r_z;
- integer_split(rem, nrem, r_z);
- n_z = padd(n_z, nrem);
+ twoprod(e_x, y, f1_hi, f1_lo);
+ twoprod(rx_hi, rx_lo, y, f2_hi, f2_lo);
+ // Sum the two terms in f using double word arithmetic. We know
+ // that |e_x| > |log2(m_x)|, except for the case where e_x==0.
+ // This means that we can use fast_twosum(f1,f2).
+ // In the case e_x == 0, e_x * y = f1 = 0, so we don't lose any
+ // accuracy by violating the assumption of fast_twosum, because
+ // it's a no-op.
+ Packet f_hi, f_lo;
+ fast_twosum(f1_hi, f1_lo, f2_hi, f2_lo, f_hi, f_lo);
+
+ // Split f into integer and fractional parts.
+ Packet n_z, r_z;
+ absolute_split(f_hi, n_z, r_z);
+ r_z = padd(r_z, f_lo);
+ Packet n_r;
+ absolute_split(r_z, n_r, r_z);
+ n_z = padd(n_z, n_r);
// We now have an accurate split of f = n_z + r_z and can compute
- // x^y = 2**{n_z + r_z) = exp(ln(2) * r_z) * 2**{n_z}.
- // The first factor we compute by calling pexp(), while multiplication
- // by an integer power of 2 can be done exactly using pldexp().
- // Note: I experimented with using Dekker's algorithms for the
- // multiplication by ln(2) here, but did not see any difference.
- Packet e_r = pexp(pmul(pset1<Packet>(Scalar(EIGEN_LN2)), r_z));
- // TODO: investigate bounds of e_r and n_z, potentially using faster
- // implementation of ldexp.
+ // x^y = 2**{n_z + r_z} = exp2(r_z) * 2**{n_z}.
+ // Since r_z is in [-0.5;0.5], we compute the first factor to high accuracy
+ // using a specialized algorithm. Multiplication by the second factor can
+ // be done exactly using pldexp(), since it is an integer power of 2.
+ // Packet e_r = fast_accurate_exp2<Scalar>()(r_z);
+ const Packet e_r = fast_accurate_exp2<Scalar>()(r_z);
return pldexp(e_r, n_z);
}
@@ -1005,66 +1252,75 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet generic_pow(const Packet& x, const Packet& y) {
typedef typename unpacket_traits<Packet>::type Scalar;
+
const Packet cst_pos_inf = pset1<Packet>(NumTraits<Scalar>::infinity());
const Packet cst_zero = pset1<Packet>(Scalar(0));
const Packet cst_one = pset1<Packet>(Scalar(1));
- const Packet cst_half = pset1<Packet>(Scalar(0.5));
const Packet cst_nan = pset1<Packet>(NumTraits<Scalar>::quiet_NaN());
- Packet abs_x = pabs(x);
+ const Packet abs_x = pabs(x);
// Predicates for sign and magnitude of x.
- Packet x_is_zero = pcmp_eq(x, cst_zero);
- Packet x_is_neg = pcmp_lt(x, cst_zero);
- Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
- Packet abs_x_is_one = pcmp_eq(abs_x, cst_one);
- Packet abs_x_is_gt_one = pcmp_lt(cst_one, abs_x);
- Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_one);
- Packet x_is_one = pandnot(abs_x_is_one, x_is_neg);
- Packet x_is_neg_one = pand(abs_x_is_one, x_is_neg);
- Packet x_is_nan = pandnot(ptrue(x), pcmp_eq(x, x));
+ const Packet x_is_zero = pcmp_eq(x, cst_zero);
+ const Packet x_is_neg = pcmp_lt(x, cst_zero);
+ const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
+ const Packet abs_x_is_one = pcmp_eq(abs_x, cst_one);
+ const Packet abs_x_is_gt_one = pcmp_lt(cst_one, abs_x);
+ const Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_one);
+ const Packet x_is_one = pandnot(abs_x_is_one, x_is_neg);
+ const Packet x_is_neg_one = pand(abs_x_is_one, x_is_neg);
+ const Packet x_is_nan = pandnot(ptrue(x), pcmp_eq(x, x));
// Predicates for sign and magnitude of y.
- Packet y_is_zero = pcmp_eq(y, cst_zero);
- Packet y_is_neg = pcmp_lt(y, cst_zero);
- Packet y_is_pos = pandnot(ptrue(y), por(y_is_zero, y_is_neg));
- Packet y_is_nan = pandnot(ptrue(y), pcmp_eq(y, y));
- Packet abs_y_is_inf = pcmp_eq(pabs(y), cst_pos_inf);
- // |y| is so large that (1+eps)^y over- or underflows.
+ const Packet y_is_one = pcmp_eq(y, cst_one);
+ const Packet y_is_zero = pcmp_eq(y, cst_zero);
+ const Packet y_is_neg = pcmp_lt(y, cst_zero);
+ const Packet y_is_pos = pandnot(ptrue(y), por(y_is_zero, y_is_neg));
+ const Packet y_is_nan = pandnot(ptrue(y), pcmp_eq(y, y));
+ const Packet abs_y_is_inf = pcmp_eq(pabs(y), cst_pos_inf);
EIGEN_CONSTEXPR Scalar huge_exponent =
- (std::numeric_limits<Scalar>::digits * Scalar(EIGEN_LOG2E)) /
+ (std::numeric_limits<Scalar>::max_exponent * Scalar(EIGEN_LN2)) /
std::numeric_limits<Scalar>::epsilon();
- Packet abs_y_is_huge = pcmp_lt(pset1<Packet>(huge_exponent), pabs(y));
+ const Packet abs_y_is_huge = pcmp_le(pset1<Packet>(huge_exponent), pabs(y));
// Predicates for whether y is integer and/or even.
- Packet y_is_int = pcmp_eq(pfloor(y), y);
- Packet y_div_2 = pmul(y, cst_half);
- Packet y_is_even = pcmp_eq(pround(y_div_2), y_div_2);
+ const Packet y_is_int = pcmp_eq(pfloor(y), y);
+ const Packet y_div_2 = pmul(y, pset1<Packet>(Scalar(0.5)));
+ const Packet y_is_even = pcmp_eq(pround(y_div_2), y_div_2);
// Predicates encoding special cases for the value of pow(x,y)
- Packet invalid_negative_x = pandnot(pandnot(pandnot(x_is_neg, abs_x_is_inf), y_is_int), abs_y_is_inf);
- Packet pow_is_one = por(por(x_is_one, y_is_zero),
- pand(x_is_neg_one,
- por(abs_y_is_inf, pandnot(y_is_even, invalid_negative_x))));
- Packet pow_is_nan = por(invalid_negative_x, por(x_is_nan, y_is_nan));
- Packet pow_is_zero = por(por(por(pand(x_is_zero, y_is_pos), pand(abs_x_is_inf, y_is_neg)),
- pand(pand(abs_x_is_lt_one, abs_y_is_huge), y_is_pos)),
- pand(pand(abs_x_is_gt_one, abs_y_is_huge), y_is_neg));
- Packet pow_is_inf = por(por(por(pand(x_is_zero, y_is_neg), pand(abs_x_is_inf, y_is_pos)),
- pand(pand(abs_x_is_lt_one, abs_y_is_huge), y_is_neg)),
- pand(pand(abs_x_is_gt_one, abs_y_is_huge), y_is_pos));
+ const Packet invalid_negative_x = pandnot(pandnot(pandnot(x_is_neg, abs_x_is_inf),
+ y_is_int),
+ abs_y_is_inf);
+ const Packet pow_is_one = por(por(x_is_one, y_is_zero),
+ pand(x_is_neg_one,
+ por(abs_y_is_inf, pandnot(y_is_even, invalid_negative_x))));
+ const Packet pow_is_nan = por(invalid_negative_x, por(x_is_nan, y_is_nan));
+ const Packet pow_is_zero = por(por(por(pand(x_is_zero, y_is_pos),
+ pand(abs_x_is_inf, y_is_neg)),
+ pand(pand(abs_x_is_lt_one, abs_y_is_huge),
+ y_is_pos)),
+ pand(pand(abs_x_is_gt_one, abs_y_is_huge),
+ y_is_neg));
+ const Packet pow_is_inf = por(por(por(pand(x_is_zero, y_is_neg),
+ pand(abs_x_is_inf, y_is_pos)),
+ pand(pand(abs_x_is_lt_one, abs_y_is_huge),
+ y_is_neg)),
+ pand(pand(abs_x_is_gt_one, abs_y_is_huge),
+ y_is_pos));
// General computation of pow(x,y) for positive x or negative x and integer y.
- Packet negate_pow_abs = pandnot(x_is_neg, y_is_even);
- Packet pow_abs = generic_pow_impl(abs_x, y);
-
- return pselect(pow_is_one, cst_one,
- pselect(pow_is_nan, cst_nan,
- pselect(pow_is_inf, cst_pos_inf,
- pselect(pow_is_zero, cst_zero,
- pselect(negate_pow_abs, pnegate(pow_abs), pow_abs)))));
+ const Packet negate_pow_abs = pandnot(x_is_neg, y_is_even);
+ const Packet pow_abs = generic_pow_impl(abs_x, y);
+ return pselect(y_is_one, x,
+ pselect(pow_is_one, cst_one,
+ pselect(pow_is_nan, cst_nan,
+ pselect(pow_is_inf, cst_pos_inf,
+ pselect(pow_is_zero, cst_zero,
+ pselect(negate_pow_abs, pnegate(pow_abs), pow_abs))))));
}
+
/* polevl (modified for Eigen)
*
* Evaluate polynomial