aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
diff options
context:
space:
mode:
authorGravatar Rasmus Munk Larsen <rmlarsen@google.com>2021-01-18 13:25:16 +0000
committerGravatar David Tellenbach <david.tellenbach@me.com>2021-01-18 13:25:16 +0000
commitcdd8fdc32e730d5a65796a791ff13a92815c59b9 (patch)
tree3ee2ebf6295a44518d55ee3f6d9e26f4ca0a8a79 /Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
parentbde6741641b7c677d901cd48db844fcea1fd32fe (diff)
Vectorize `pow(x, y)`. This closes https://gitlab.com/libeigen/eigen/-/issues/2085, which also contains a description of the algorithm.
I ran some testing (comparing to `std::pow(double(x), double(y)))` for `x` in the set of all (positive) floats in the interval `[std::sqrt(std::numeric_limits<float>::min()), std::sqrt(std::numeric_limits<float>::max())]`, and `y` in `{2, sqrt(2), -sqrt(2)}` I get the following error statistics: ``` max_rel_error = 8.34405e-07 rms_rel_error = 2.76654e-07 ``` If I widen the range to all normal float I see lower accuracy for arguments where the result is subnormal, e.g. for `y = sqrt(2)`: ``` max_rel_error = 0.666667 rms = 6.8727e-05 count = 1335165689 argmax = 2.56049e-32, 2.10195e-45 != 1.4013e-45 ``` which seems reasonable, since these results are subnormals with only couple of significant bits left.
Diffstat (limited to 'Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h')
-rw-r--r--Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h165
1 files changed, 165 insertions, 0 deletions
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
index f40093455..9a1feb0d9 100644
--- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
@@ -793,6 +793,171 @@ Packet psqrt_complex(const Packet& a) {
pselect(is_real_inf, real_inf_result,result));
}
+
+// This function implements the Veltkamp splitting. Given a floating point
+// number x it returns the pair {x_hi, x_lo} such that x_hi + x_lo = x holds
+// exactly and that half of the significant of x fits in x_hi.
+// This code corresponds to Algorithms 3 and 4 in
+// https://hal.inria.fr/hal-01774587v2/document
+template<typename Packet>
+EIGEN_STRONG_INLINE
+void veltkamp_splitting(const Packet& x, Packet& x_hi, Packet& x_lo) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ EIGEN_CONSTEXPR int shift = (NumTraits<Scalar>::digits() + 1) / 2;
+ EIGEN_CONSTEXPR Scalar shift_scale = Scalar(uint64_t(1) << shift);
+ Packet gamma = pmul(pset1<Packet>(shift_scale + 1), x);
+#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
+ x_hi = pmadd(pset1<Packet>(-shift_scale), x, gamma);
+#else
+ Packet rho = psub(x, gamma);
+ x_hi = padd(rho, gamma);
+#endif
+ x_lo = psub(x, x_hi);
+}
+
+// This function splits x into the nearest integer n and fractional part r,
+// such that x = n + r holds exactly.
+template<typename Packet>
+EIGEN_STRONG_INLINE
+void integer_split(const Packet& x, Packet& n, Packet& r) {
+ n = pround(x);
+ r = psub(x, n);
+}
+
+// This function implements Dekker's algorithm for two products {x * y1, x * y2} with
+// a shared factor. Given floating point numbers {x, y1, y2} computes the pairs
+// {p1, r1} and {p2, r2} such that x * y1 = p1 + r1 holds exactly and
+// p1 = fl(x * y1), and x * y2 = p2 + r2 holds exactly and p2 = fl(x * y2).
+template<typename Packet>
+EIGEN_STRONG_INLINE
+void double_dekker(const Packet& x, const Packet& y1, const Packet& y2,
+ Packet& p1, Packet& r1, Packet& p2, Packet& r2) {
+ Packet x_hi, x_lo, y1_hi, y1_lo, y2_hi, y2_lo;
+ veltkamp_splitting(x, x_hi, x_lo);
+ veltkamp_splitting(y1, y1_hi, y1_lo);
+ veltkamp_splitting(y2, y2_hi, y2_lo);
+
+ p1 = pmul(x, y1);
+ r1 = pmadd(x_hi, y1_hi, pnegate(p1));
+ r1 = pmadd(x_hi, y1_lo, r1);
+ r1 = pmadd(x_lo, y1_hi, r1);
+ r1 = pmadd(x_lo, y1_lo, r1);
+
+ p2 = pmul(x, y2);
+ r2 = pmadd(x_hi, y2_hi, pnegate(p2));
+ r2 = pmadd(x_hi, y2_lo, r2);
+ r2 = pmadd(x_lo, y2_hi, r2);
+ r2 = pmadd(x_lo, y2_lo, r2);
+}
+
+// This function implements the non-trivial case of pow(x,y) where x is
+// positive and y is (possibly) non-integer.
+// Formally, pow(x,y) = 2**(y * log2(x))
+template<typename Packet>
+EIGEN_STRONG_INLINE
+Packet generic_pow_impl(const Packet& x, const Packet& y) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ // Split x into exponent e_x and mantissa m_x.
+ Packet e_x;
+ Packet m_x = pfrexp(x, e_x);
+
+ // Adjust m_x to lie in [0.75:1.5) to minimize absolute error in log2(m_x).
+ Packet m_x_scale_mask = pcmp_lt(m_x, pset1<Packet>(Scalar(0.75)));
+ m_x = pselect(m_x_scale_mask, pmul(pset1<Packet>(Scalar(2)), m_x), m_x);
+ e_x = pselect(m_x_scale_mask, psub(e_x, pset1<Packet>(Scalar(1))), e_x);
+
+ Packet r_x = plog2(m_x);
+
+ // Compute the two terms {y * e_x, y * r_x} in f = y * log2(x) with doubled
+ // precision using Dekker's algorithm.
+ Packet f1_hi, f1_lo, f2_hi, f2_lo;
+ double_dekker(y, e_x, r_x, f1_hi, f1_lo, f2_hi, f2_lo);
+
+ // Separate f into integer and fractional parts, keeping f1_hi, and f2_hi
+ // separate to avoid cancellation.
+ Packet n1, r1, n2, r2;
+ integer_split(f1_hi, n1, r1);
+ integer_split(f2_hi, n2, r2);
+
+ // Add up integer parts and sum the remainders.
+ Packet n_z = padd(n1, n2);
+ // Notice: I experimented with using compensated (Kahan) summation here,
+ // but it does not seem to matter.
+ Packet rem = padd(padd(f1_lo, f2_lo), padd(r1, r2));
+
+ // Extract any additional integer part that may have accumulated in rem.
+ Packet nrem, r_z;
+ integer_split(rem, nrem, r_z);
+ n_z = padd(n_z, nrem);
+
+ // We now have an accurate split of f = n_z + r_z and can compute
+ // x^y = 2**{n_z + r_z) = exp(ln(2) * r_z) * 2**{n_z}.
+ // The first factor we compute by calling pexp(), while multiplication
+ // by an integer power of 2 can be done exactly using pldexp().
+ // Note: I experimented with using Dekker's algorithms for the
+ // multiplication by ln(2) here, but did not see any difference.
+ Packet e_r = pexp(pmul(pset1<Packet>(Scalar(EIGEN_LN2)), r_z));
+ return pldexp(e_r, n_z);
+}
+
+// Generic implementation of pow(x,y).
+template<typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet generic_pow(const Packet& x, const Packet& y) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ const Packet cst_pos_inf = pset1<Packet>(NumTraits<Scalar>::infinity());
+ const Packet cst_zero = pset1<Packet>(Scalar(0));
+ const Packet cst_one = pset1<Packet>(Scalar(1));
+ const Packet cst_nan = pset1<Packet>(NumTraits<Scalar>::quiet_NaN());
+
+ Packet abs_x = pabs(x);
+ // Predicates for sign and magnitude of x.
+ Packet x_is_zero = pcmp_eq(x, cst_zero);
+ Packet x_is_neg = pcmp_lt(x, cst_zero);
+ Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
+ Packet abs_x_is_one = pcmp_eq(abs_x, cst_one);
+ Packet abs_x_is_gt_one = pcmp_lt(cst_one, abs_x);
+ Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_one);
+ Packet x_is_one = pandnot(abs_x_is_one, x_is_neg);
+ Packet x_is_neg_one = pand(abs_x_is_one, x_is_neg);
+ Packet x_is_nan = pandnot(ptrue(x), pcmp_eq(x, x));
+
+ // Predicates for sign and magnitude of y.
+ Packet y_is_zero = pcmp_eq(y, cst_zero);
+ Packet y_is_neg = pcmp_lt(y, cst_zero);
+ Packet y_is_pos = pandnot(ptrue(y), por(y_is_zero, y_is_neg));
+ Packet y_is_nan = pandnot(ptrue(y), pcmp_eq(y, y));
+ Packet abs_y_is_inf = pcmp_eq(pabs(y), cst_pos_inf);
+
+ // Predicates for whether y is integer and/or even.
+ Packet y_is_int = pcmp_eq(pfloor(y), y);
+ Packet y_div_2 = pldexp(y, pset1<Packet>(Scalar(-1)));
+ Packet y_is_even = pcmp_eq(pround(y_div_2), y_div_2);
+
+ // Predicates encoding special cases for the value of pow(x,y)
+ Packet invalid_negative_x = pandnot(pandnot(pandnot(x_is_neg, abs_x_is_inf), y_is_int), abs_y_is_inf);
+ Packet pow_is_nan = por(invalid_negative_x, por(x_is_nan, y_is_nan));
+ Packet pow_is_one = por(por(y_is_zero, x_is_one), pand(x_is_neg_one, abs_y_is_inf));
+ Packet pow_is_zero = por(por(por(pand(x_is_zero, y_is_pos), pand(abs_x_is_inf, y_is_neg)),
+ pand(pand(abs_x_is_lt_one, abs_y_is_inf), y_is_pos)),
+ pand(pand(abs_x_is_gt_one, abs_y_is_inf), y_is_neg));
+ Packet pow_is_inf = por(por(por(pand(x_is_zero, y_is_neg), pand(abs_x_is_inf, y_is_pos)),
+ pand(pand(abs_x_is_lt_one, abs_y_is_inf), y_is_neg)),
+ pand(pand(abs_x_is_gt_one, abs_y_is_inf), y_is_pos));
+
+ // General computation of pow(x,y) for positive x or negative x and integer y.
+ Packet negate_pow_abs = pandnot(x_is_neg, y_is_even);
+ Packet pow_abs = generic_pow_impl(abs_x, y);
+
+ return pselect(pow_is_one, cst_one,
+ pselect(pow_is_nan, cst_nan,
+ pselect(pow_is_inf, cst_pos_inf,
+ pselect(pow_is_zero, cst_zero,
+ pselect(negate_pow_abs, pnegate(pow_abs), pow_abs)))));
+}
+
+
/* polevl (modified for Eigen)
*
* Evaluate polynomial