From 125cc9a5df6074756b89ea8aaa4e9a4b44b0f7e9 Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Tue, 8 Dec 2020 18:13:35 -0800 Subject: Implement vectorized complex square root. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #1905 Measured speedup for sqrt of `complex` on Skylake: SSE: ``` name old time/op new time/op delta BM_eigen_sqrt_ctype/1 49.4ns ± 0% 54.3ns ± 0% +10.01% BM_eigen_sqrt_ctype/8 332ns ± 0% 50ns ± 1% -84.97% BM_eigen_sqrt_ctype/64 2.81µs ± 1% 0.38µs ± 0% -86.49% BM_eigen_sqrt_ctype/512 23.8µs ± 0% 3.0µs ± 0% -87.32% BM_eigen_sqrt_ctype/4k 202µs ± 0% 24µs ± 2% -88.03% BM_eigen_sqrt_ctype/32k 1.63ms ± 0% 0.19ms ± 0% -88.18% BM_eigen_sqrt_ctype/256k 13.0ms ± 0% 1.5ms ± 1% -88.20% BM_eigen_sqrt_ctype/1M 52.1ms ± 0% 6.2ms ± 0% -88.18% ``` AVX2: ``` name old cpu/op new cpu/op delta BM_eigen_sqrt_ctype/1 53.6ns ± 0% 55.6ns ± 0% +3.71% BM_eigen_sqrt_ctype/8 334ns ± 0% 27ns ± 0% -91.86% BM_eigen_sqrt_ctype/64 2.79µs ± 0% 0.22µs ± 2% -92.28% BM_eigen_sqrt_ctype/512 23.8µs ± 1% 1.7µs ± 1% -92.81% BM_eigen_sqrt_ctype/4k 201µs ± 0% 14µs ± 1% -93.24% BM_eigen_sqrt_ctype/32k 1.62ms ± 0% 0.11ms ± 1% -93.29% BM_eigen_sqrt_ctype/256k 13.0ms ± 0% 0.9ms ± 1% -93.31% BM_eigen_sqrt_ctype/1M 52.0ms ± 0% 3.5ms ± 1% -93.31% ``` AVX512: ``` name old cpu/op new cpu/op delta BM_eigen_sqrt_ctype/1 53.7ns ± 0% 56.2ns ± 1% +4.75% BM_eigen_sqrt_ctype/8 334ns ± 0% 18ns ± 2% -94.63% BM_eigen_sqrt_ctype/64 2.79µs ± 0% 0.12µs ± 1% -95.54% BM_eigen_sqrt_ctype/512 23.9µs ± 1% 1.0µs ± 1% -95.89% BM_eigen_sqrt_ctype/4k 202µs ± 0% 8µs ± 1% -96.13% BM_eigen_sqrt_ctype/32k 1.63ms ± 0% 0.06ms ± 1% -96.15% BM_eigen_sqrt_ctype/256k 13.0ms ± 0% 0.5ms ± 4% -96.11% BM_eigen_sqrt_ctype/1M 52.1ms ± 0% 2.0ms ± 1% -96.13% ``` --- .../Core/arch/Default/GenericPacketMathFunctions.h | 114 +++++++++++++++++++++ 1 file changed, 114 insertions(+) (limited to 'Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h') diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index c6bb89b05..45cc780f1 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -673,6 +673,120 @@ Packet pcos_float(const Packet& x) return psincos_float(x); } + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet psqrt_complex(const Packet& a) { + typedef typename unpacket_traits::type Scalar; + typedef typename Scalar::value_type RealScalar; + typedef typename unpacket_traits::as_real RealPacket; + + // Computes the principal sqrt of the complex numbers in the input. + // + // For example, for packets containing 2 complex numbers stored in interleaved format + // a = [a0, a1] = [x0, y0, x1, y1], + // where x0 = real(a0), y0 = imag(a0) etc., this function returns + // b = [b0, b1] = [u0, v0, u1, v1], + // such that b0^2 = a0, b1^2 = a1. + // + // To derive the formula for the complex square roots, let's consider the equation for + // a single complex square root of the number x + i*y. We want to find real numbers + // u and v such that + // (u + i*v)^2 = x + i*y <=> + // u^2 - v^2 + i*2*u*v = x + i*v. + // By equating the real and imaginary parts we get: + // u^2 - v^2 = x + // 2*u*v = y. + // + // For x >= 0, this has the numerically stable solution + // u = sqrt(0.5 * (x + sqrt(x^2 + y^2))) + // v = 0.5 * (y / u) + // and for x < 0, + // v = sign(y) * sqrt(0.5 * (x + sqrt(x^2 + y^2))) + // u = |0.5 * (y / v)| + // + // To avoid unnecessary over- and underflow, we compute sqrt(x^2 + y^2) as + // l = max(|x|, |y|) * sqrt(1 + (min(|x|, |y|) / max(|x|, |y|))^2) , + + // In the following, without lack of generality, we have annotated the code, assuming + // that the input is a packet of 2 complex numbers. + // + // Step 1. Compute l = [l0, l0, l1, l1], where + // l0 = sqrt(x0^2 + y0^2), l1 = sqrt(x1^2 + y1^2) + // To avoid over- and underflow, we use the stable formula for each hypotenuse + // l0 = (min0 == 0 ? max0 : max0 * sqrt(1 + (min0/max0)**2)), + // where max0 = max(|x0|, |y0|), min0 = min(|x0|, |y0|), and similarly for l1. + + Packet a_flip = pcplxflip(a); + RealPacket a_abs = pabs(a.v); // [|x0|, |y0|, |x1|, |y1|] + RealPacket a_abs_flip = pabs(a_flip.v); // [|y0|, |x0|, |y1|, |x1|] + RealPacket a_max = pmax(a_abs, a_abs_flip); + RealPacket a_min = pmin(a_abs, a_abs_flip); + RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min)); + RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max)); + RealPacket r = pdiv(a_min, a_max); + const RealPacket cst_one = pset1(RealScalar(1)); + RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1] + // Set l to a_max if a_min is zero. + l = pselect(a_min_zero_mask, a_max, l); + + // Step 2. Compute [rho0, *, rho1, *], where + // rho0 = sqrt(0.5 * (l0 + |x0|)), rho1 = sqrt(0.5 * (l1 + |x1|)) + // We don't care about the imaginary parts computed here. They will be overwritten later. + const RealPacket cst_half = pset1(RealScalar(0.5)); + Packet rho; + rho.v = psqrt(pmul(cst_half, padd(a_abs, l))); + + // Step 3. Compute [rho0, eta0, rho1, eta1], where + // eta0 = (y0 / l0) / 2, and eta1 = (y1 / l1) / 2. + // set eta = 0 of input is 0 + i0. + RealPacket eta = pandnot(pmul(cst_half, pdiv(a.v, pcplxflip(rho).v)), a_max_zero_mask); + RealPacket real_mask = peven_mask(a.v); + Packet positive_real_result; + // Compute result for inputs with positive real part. + positive_real_result.v = pselect(real_mask, rho.v, eta); + + // Step 4. Compute solution for inputs with negative real part: + // [|eta0|, sign(y0)*rho0, |eta1|, sign(y1)*rho1] + const RealPacket cst_imag_sign_mask = pset1(Scalar(RealScalar(0.0), RealScalar(-0.0))).v; + RealPacket imag_signs = pand(a.v, cst_imag_sign_mask); + Packet negative_real_result; + // Notice that rho is positive, so taking it's absolute value is a noop. + negative_real_result.v = por(pabs(pcplxflip(positive_real_result).v), imag_signs); + + // Step 5. Select solution branch based on the sign of the real parts. + Packet negative_real_mask; + negative_real_mask.v = pcmp_lt(pand(real_mask, a.v), pzero(a.v)); + negative_real_mask.v = por(negative_real_mask.v, pcplxflip(negative_real_mask).v); + Packet result = pselect(negative_real_mask, negative_real_result, positive_real_result); + + // Step 6. Handle special cases for infinities: + // * If z is (x,+∞), the result is (+∞,+∞) even if x is NaN + // * If z is (x,-∞), the result is (+∞,-∞) even if x is NaN + // * If z is (-∞,y), the result is (0*|y|,+∞) for finite or NaN y + // * If z is (+∞,y), the result is (+∞,0*|y|) for finite or NaN y + const RealPacket cst_pos_inf = pset1(NumTraits::infinity()); + Packet is_inf; + is_inf.v = pcmp_eq(a_abs, cst_pos_inf); + Packet is_real_inf; + is_real_inf.v = pand(is_inf.v, real_mask); + is_real_inf = por(is_real_inf, pcplxflip(is_real_inf)); + // prepare packet of (+∞,0*|y|) or (0*|y|,+∞), depending on the sign of the infinite real part. + Packet real_inf_result; + real_inf_result.v = pmul(a_abs, pset1(Scalar(RealScalar(1.0), RealScalar(0.0))).v); + real_inf_result.v = pselect(negative_real_mask.v, pcplxflip(real_inf_result).v, real_inf_result.v); + // prepare packet of (+∞,+∞) or (+∞,-∞), depending on the sign of the infinite imaginary part. + Packet is_imag_inf; + is_imag_inf.v = pandnot(is_inf.v, real_mask); + is_imag_inf = por(is_imag_inf, pcplxflip(is_imag_inf)); + Packet imag_inf_result; + imag_inf_result.v = por(pand(cst_pos_inf, real_mask), pandnot(a.v, real_mask)); + + return pselect(is_imag_inf, imag_inf_result, + pselect(is_real_inf, real_inf_result,result)); +} + /* polevl (modified for Eigen) * * Evaluate polynomial -- cgit v1.2.3