From bde6741641b7c677d901cd48db844fcea1fd32fe Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Sat, 16 Jan 2021 10:22:07 -0800 Subject: Improved std::complex sqrt and rsqrt. Replaces `std::sqrt` with `complex_sqrt` for all platforms (previously `complex_sqrt` was only used for CUDA and MSVC), and implements custom `complex_rsqrt`. Also introduces `numext::rsqrt` to simplify implementation, and modified `numext::hypot` to adhere to IEEE IEC 6059 for special cases. The `complex_sqrt` and `complex_rsqrt` implementations were found to be significantly faster than `std::sqrt>` and `1/numext::sqrt>`. Benchmark file attached. ``` GCC 10, Intel Xeon, x86_64: --------------------------------------------------------------------------- Benchmark Time CPU Iterations --------------------------------------------------------------------------- BM_Sqrt> 9.21 ns 9.21 ns 73225448 BM_StdSqrt> 17.1 ns 17.1 ns 40966545 BM_Sqrt> 8.53 ns 8.53 ns 81111062 BM_StdSqrt> 21.5 ns 21.5 ns 32757248 BM_Rsqrt> 10.3 ns 10.3 ns 68047474 BM_DivSqrt> 16.3 ns 16.3 ns 42770127 BM_Rsqrt> 11.3 ns 11.3 ns 61322028 BM_DivSqrt> 16.5 ns 16.5 ns 42200711 Clang 11, Intel Xeon, x86_64: --------------------------------------------------------------------------- Benchmark Time CPU Iterations --------------------------------------------------------------------------- BM_Sqrt> 7.46 ns 7.45 ns 90742042 BM_StdSqrt> 16.6 ns 16.6 ns 42369878 BM_Sqrt> 8.49 ns 8.49 ns 81629030 BM_StdSqrt> 21.8 ns 21.7 ns 31809588 BM_Rsqrt> 8.39 ns 8.39 ns 82933666 BM_DivSqrt> 14.4 ns 14.4 ns 48638676 BM_Rsqrt> 9.83 ns 9.82 ns 70068956 BM_DivSqrt> 15.7 ns 15.7 ns 44487798 Clang 9, Pixel 2, aarch64: --------------------------------------------------------------------------- Benchmark Time CPU Iterations --------------------------------------------------------------------------- BM_Sqrt> 24.2 ns 24.1 ns 28616031 BM_StdSqrt> 104 ns 103 ns 6826926 BM_Sqrt> 31.8 ns 31.8 ns 22157591 BM_StdSqrt> 128 ns 128 ns 5437375 BM_Rsqrt> 31.9 ns 31.8 ns 22384383 BM_DivSqrt> 99.2 ns 98.9 ns 7250438 BM_Rsqrt> 46.0 ns 45.8 ns 15338689 BM_DivSqrt> 119 ns 119 ns 5898944 ``` --- test/stable_norm.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) (limited to 'test/stable_norm.cpp') diff --git a/test/stable_norm.cpp b/test/stable_norm.cpp index ee5f91674..008e35d87 100644 --- a/test/stable_norm.cpp +++ b/test/stable_norm.cpp @@ -161,8 +161,12 @@ template void stable_norm(const MatrixType& m) // mix { - Index i2 = internal::random(0,rows-1); - Index j2 = internal::random(0,cols-1); + // Ensure unique indices otherwise inf may be overwritten by NaN. + Index i2, j2; + do { + i2 = internal::random(0,rows-1); + j2 = internal::random(0,cols-1); + } while (i2 == i && j2 == j); v = vrand; v(i,j) = -std::numeric_limits::infinity(); v(i2,j2) = std::numeric_limits::quiet_NaN(); @@ -170,7 +174,8 @@ template void stable_norm(const MatrixType& m) VERIFY(!(numext::isfinite)(v.norm())); VERIFY((numext::isnan)(v.norm())); VERIFY(!(numext::isfinite)(v.stableNorm())); VERIFY((numext::isnan)(v.stableNorm())); VERIFY(!(numext::isfinite)(v.blueNorm())); VERIFY((numext::isnan)(v.blueNorm())); - VERIFY(!(numext::isfinite)(v.hypotNorm())); VERIFY((numext::isnan)(v.hypotNorm())); + // hypot propagates inf over NaN. + VERIFY(!(numext::isfinite)(v.hypotNorm())); VERIFY((numext::isinf)(v.hypotNorm())); } // stableNormalize[d] -- cgit v1.2.3