From bde6741641b7c677d901cd48db844fcea1fd32fe Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Sat, 16 Jan 2021 10:22:07 -0800 Subject: Improved std::complex sqrt and rsqrt. Replaces `std::sqrt` with `complex_sqrt` for all platforms (previously `complex_sqrt` was only used for CUDA and MSVC), and implements custom `complex_rsqrt`. Also introduces `numext::rsqrt` to simplify implementation, and modified `numext::hypot` to adhere to IEEE IEC 6059 for special cases. The `complex_sqrt` and `complex_rsqrt` implementations were found to be significantly faster than `std::sqrt>` and `1/numext::sqrt>`. Benchmark file attached. ``` GCC 10, Intel Xeon, x86_64: --------------------------------------------------------------------------- Benchmark Time CPU Iterations --------------------------------------------------------------------------- BM_Sqrt> 9.21 ns 9.21 ns 73225448 BM_StdSqrt> 17.1 ns 17.1 ns 40966545 BM_Sqrt> 8.53 ns 8.53 ns 81111062 BM_StdSqrt> 21.5 ns 21.5 ns 32757248 BM_Rsqrt> 10.3 ns 10.3 ns 68047474 BM_DivSqrt> 16.3 ns 16.3 ns 42770127 BM_Rsqrt> 11.3 ns 11.3 ns 61322028 BM_DivSqrt> 16.5 ns 16.5 ns 42200711 Clang 11, Intel Xeon, x86_64: --------------------------------------------------------------------------- Benchmark Time CPU Iterations --------------------------------------------------------------------------- BM_Sqrt> 7.46 ns 7.45 ns 90742042 BM_StdSqrt> 16.6 ns 16.6 ns 42369878 BM_Sqrt> 8.49 ns 8.49 ns 81629030 BM_StdSqrt> 21.8 ns 21.7 ns 31809588 BM_Rsqrt> 8.39 ns 8.39 ns 82933666 BM_DivSqrt> 14.4 ns 14.4 ns 48638676 BM_Rsqrt> 9.83 ns 9.82 ns 70068956 BM_DivSqrt> 15.7 ns 15.7 ns 44487798 Clang 9, Pixel 2, aarch64: --------------------------------------------------------------------------- Benchmark Time CPU Iterations --------------------------------------------------------------------------- BM_Sqrt> 24.2 ns 24.1 ns 28616031 BM_StdSqrt> 104 ns 103 ns 6826926 BM_Sqrt> 31.8 ns 31.8 ns 22157591 BM_StdSqrt> 128 ns 128 ns 5437375 BM_Rsqrt> 31.9 ns 31.8 ns 22384383 BM_DivSqrt> 99.2 ns 98.9 ns 7250438 BM_Rsqrt> 46.0 ns 45.8 ns 15338689 BM_DivSqrt> 119 ns 119 ns 5898944 ``` --- Eigen/src/Core/MathFunctionsImpl.h | 60 ++++++++++++++++++++++++++++++++------ 1 file changed, 51 insertions(+), 9 deletions(-) (limited to 'Eigen/src/Core/MathFunctionsImpl.h') diff --git a/Eigen/src/Core/MathFunctionsImpl.h b/Eigen/src/Core/MathFunctionsImpl.h index 9222285b4..0d3f317bb 100644 --- a/Eigen/src/Core/MathFunctionsImpl.h +++ b/Eigen/src/Core/MathFunctionsImpl.h @@ -79,6 +79,12 @@ template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE RealScalar positive_real_hypot(const RealScalar& x, const RealScalar& y) { + // IEEE IEC 6059 special cases. + if ((numext::isinf)(x) || (numext::isinf)(y)) + return NumTraits::infinity(); + if ((numext::isnan)(x) || (numext::isnan)(y)) + return NumTraits::quiet_NaN(); + EIGEN_USING_STD(sqrt); RealScalar p, qp; p = numext::maxi(x,y); @@ -128,20 +134,56 @@ EIGEN_DEVICE_FUNC std::complex complex_sqrt(const std::complex& z) { const T x = numext::real(z); const T y = numext::imag(z); const T zero = T(0); - const T cst_half = T(0.5); + const T w = numext::sqrt(T(0.5) * (numext::abs(x) + numext::hypot(x, y))); - // Special case of isinf(y) - if ((numext::isinf)(y)) { - return std::complex(std::numeric_limits::infinity(), y); - } - - T w = numext::sqrt(cst_half * (numext::abs(x) + numext::abs(z))); return - x == zero ? std::complex(w, y < zero ? -w : w) - : x > zero ? std::complex(w, y / (2 * w)) + (numext::isinf)(y) ? std::complex(NumTraits::infinity(), y) + : x == zero ? std::complex(w, y < zero ? -w : w) + : x > zero ? std::complex(w, y / (2 * w)) : std::complex(numext::abs(y) / (2 * w), y < zero ? -w : w ); } +// Generic complex rsqrt implementation. +template +EIGEN_DEVICE_FUNC std::complex complex_rsqrt(const std::complex& z) { + // Computes the principal reciprocal sqrt of the input. + // + // For a complex reciprocal square root of the number z = x + i*y. We want to + // find real numbers u and v such that + // (u + i*v)^2 = 1 / (x + i*y) <=> + // u^2 - v^2 + i*2*u*v = x/|z|^2 - i*v/|z|^2. + // By equating the real and imaginary parts we get: + // u^2 - v^2 = x/|z|^2 + // 2*u*v = y/|z|^2. + // + // For x >= 0, this has the numerically stable solution + // u = sqrt(0.5 * (x + |z|)) / |z| + // v = -y / (2 * u * |z|) + // and for x < 0, + // v = -sign(y) * sqrt(0.5 * (-x + |z|)) / |z| + // u = -y / (2 * v * |z|) + // + // Letting w = sqrt(0.5 * (|x| + |z|)), + // if x == 0: u = w / |z|, v = -sign(y) * w / |z| + // if x > 0: u = w / |z|, v = -y / (2 * w * |z|) + // if x < 0: u = |y| / (2 * w * |z|), v = -sign(y) * w / |z| + + const T x = numext::real(z); + const T y = numext::imag(z); + const T zero = T(0); + + const T abs_z = numext::hypot(x, y); + const T w = numext::sqrt(T(0.5) * (numext::abs(x) + abs_z)); + const T woz = w / abs_z; + // Corner cases consistent with 1/sqrt(z) on gcc/clang. + return + abs_z == zero ? std::complex(NumTraits::infinity(), NumTraits::quiet_NaN()) + : ((numext::isinf)(x) || (numext::isinf)(y)) ? std::complex(zero, zero) + : x == zero ? std::complex(woz, y < zero ? woz : -woz) + : x > zero ? std::complex(woz, -y / (2 * w * abs_z)) + : std::complex(numext::abs(y) / (2 * w * abs_z), y < zero ? woz : -woz ); +} + } // end namespace internal } // end namespace Eigen -- cgit v1.2.3