diff options
author | Antonio Sanchez <cantonios@google.com> | 2021-01-07 09:39:05 -0800 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2021-01-08 01:17:19 +0000 |
commit | f149e0ebc3d3d5ca63234e58ca72690caf07e3b5 (patch) | |
tree | 8c5431fd057c96b8231be84b2908d130b49d61ec | |
parent | 8d9cfba799ce3462c12568a36392e0abf36fc62d (diff) |
Fix MSVC complex sqrt and packetmath test.
MSVC incorrectly handles `inf` cases for `std::sqrt<std::complex<T>>`.
Here we replace it with a custom version (currently used on GPU).
Also fixed the `packetmath` test, which previously skipped several
corner cases since `CHECK_CWISE1` only tests the first `PacketSize`
elements.
-rw-r--r-- | Eigen/src/Core/MathFunctions.h | 16 | ||||
-rw-r--r-- | Eigen/src/Core/MathFunctionsImpl.h | 44 | ||||
-rw-r--r-- | Eigen/src/Core/arch/CUDA/Complex.h | 46 | ||||
-rw-r--r-- | test/packetmath.cpp | 14 | ||||
-rw-r--r-- | test/packetmath_test_shared.h | 11 |
5 files changed, 84 insertions, 47 deletions
diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index 928bc8e72..5b5ca46f6 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -338,6 +338,22 @@ struct sqrt_impl } }; +// Complex sqrt defined in MathFunctionsImpl.h. +template<typename T> std::complex<T> complex_sqrt(const std::complex<T>& a_x); + +// MSVC incorrectly handles inf cases. +#if EIGEN_COMP_MSVC > 0 +template<typename T> +struct sqrt_impl<std::complex<T> > +{ + EIGEN_DEVICE_FUNC + static EIGEN_ALWAYS_INLINE std::complex<T> run(const std::complex<T>& x) + { + return complex_sqrt<T>(x); + } +}; +#endif + template<typename Scalar> struct sqrt_retval { diff --git a/Eigen/src/Core/MathFunctionsImpl.h b/Eigen/src/Core/MathFunctionsImpl.h index 8288ad834..8ecddebf6 100644 --- a/Eigen/src/Core/MathFunctionsImpl.h +++ b/Eigen/src/Core/MathFunctionsImpl.h @@ -99,6 +99,50 @@ struct hypot_impl } }; +// Generic complex sqrt implementation that correctly handles corner cases +// according to https://en.cppreference.com/w/cpp/numeric/complex/sqrt +template<typename T> +EIGEN_DEVICE_FUNC std::complex<T> complex_sqrt(const std::complex<T>& z) { + // Computes the principal sqrt of the input. + // + // For a complex square root of the number x + i*y. We want to find real + // numbers u and v such that + // (u + i*v)^2 = x + i*y <=> + // u^2 - v^2 + i*2*u*v = x + i*v. + // By equating the real and imaginary parts we get: + // u^2 - v^2 = x + // 2*u*v = y. + // + // For x >= 0, this has the numerically stable solution + // u = sqrt(0.5 * (x + sqrt(x^2 + y^2))) + // v = y / (2 * u) + // and for x < 0, + // v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2))) + // u = y / (2 * v) + // + // Letting w = sqrt(0.5 * (|x| + |z|)), + // if x == 0: u = w, v = sign(y) * w + // if x > 0: u = w, v = y / (2 * w) + // if x < 0: u = |y| / (2 * w), v = sign(y) * w + + const T x = numext::real(z); + const T y = numext::imag(z); + const T zero = T(0); + const T cst_half = T(0.5); + + // Special case of isinf(y) + if ((numext::isinf)(y)) { + const T inf = std::numeric_limits<T>::infinity(); + return std::complex<T>(inf, y); + } + + T w = numext::sqrt(cst_half * (numext::abs(x) + numext::abs(z))); + return + x == zero ? std::complex<T>(w, y < zero ? -w : w) + : x > zero ? std::complex<T>(w, y / (2 * w)) + : std::complex<T>(numext::abs(y) / (2 * w), y < zero ? -w : w ); +} + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/CUDA/Complex.h b/Eigen/src/Core/arch/CUDA/Complex.h index 69334cafe..ab0207cac 100644 --- a/Eigen/src/Core/arch/CUDA/Complex.h +++ b/Eigen/src/Core/arch/CUDA/Complex.h @@ -95,46 +95,12 @@ template<typename T> struct scalar_quotient_op<const std::complex<T>, const std: template<typename T> struct scalar_quotient_op<std::complex<T>, std::complex<T> > : scalar_quotient_op<const std::complex<T>, const std::complex<T> > {}; template<typename T> -struct sqrt_impl<std::complex<T>> { - static EIGEN_DEVICE_FUNC std::complex<T> run(const std::complex<T>& z) { - // Computes the principal sqrt of the input. - // - // For a complex square root of the number x + i*y. We want to find real - // numbers u and v such that - // (u + i*v)^2 = x + i*y <=> - // u^2 - v^2 + i*2*u*v = x + i*v. - // By equating the real and imaginary parts we get: - // u^2 - v^2 = x - // 2*u*v = y. - // - // For x >= 0, this has the numerically stable solution - // u = sqrt(0.5 * (x + sqrt(x^2 + y^2))) - // v = y / (2 * u) - // and for x < 0, - // v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2))) - // u = y / (2 * v) - // - // Letting w = sqrt(0.5 * (|x| + |z|)), - // if x == 0: u = w, v = sign(y) * w - // if x > 0: u = w, v = y / (2 * w) - // if x < 0: u = |y| / (2 * w), v = sign(y) * w - - const T x = numext::real(z); - const T y = numext::imag(z); - const T zero = T(0); - const T cst_half = T(0.5); - - // Special case of isinf(y) - if ((numext::isinf)(y)) { - const T inf = std::numeric_limits<T>::infinity(); - return std::complex<T>(inf, y); - } - - T w = numext::sqrt(cst_half * (numext::abs(x) + numext::abs(z))); - return - x == zero ? std::complex<T>(w, y < zero ? -w : w) - : x > zero ? std::complex<T>(w, y / (2 * w)) - : std::complex<T>(numext::abs(y) / (2 * w), y < zero ? -w : w ); +struct sqrt_impl<std::complex<T> > +{ + EIGEN_DEVICE_FUNC + static EIGEN_ALWAYS_INLINE std::complex<T> run(const std::complex<T>& x) + { + return complex_sqrt<T>(x); } }; diff --git a/test/packetmath.cpp b/test/packetmath.cpp index f19d72502..ab9bec183 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -933,7 +933,7 @@ void packetmath_complex() { for (int i = 0; i < size; ++i) { data1[i] = Scalar(internal::random<RealScalar>(), internal::random<RealScalar>()); } - CHECK_CWISE1(numext::sqrt, internal::psqrt); + CHECK_CWISE1_N(numext::sqrt, internal::psqrt, size); // Test misc. corner cases. const RealScalar zero = RealScalar(0); @@ -944,32 +944,32 @@ void packetmath_complex() { data1[1] = Scalar(-zero, zero); data1[2] = Scalar(one, zero); data1[3] = Scalar(zero, one); - CHECK_CWISE1(numext::sqrt, internal::psqrt); + CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4); data1[0] = Scalar(-one, zero); data1[1] = Scalar(zero, -one); data1[2] = Scalar(one, one); data1[3] = Scalar(-one, -one); - CHECK_CWISE1(numext::sqrt, internal::psqrt); + CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4); data1[0] = Scalar(inf, zero); data1[1] = Scalar(zero, inf); data1[2] = Scalar(-inf, zero); data1[3] = Scalar(zero, -inf); - CHECK_CWISE1(numext::sqrt, internal::psqrt); + CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4); data1[0] = Scalar(inf, inf); data1[1] = Scalar(-inf, inf); data1[2] = Scalar(inf, -inf); data1[3] = Scalar(-inf, -inf); - CHECK_CWISE1(numext::sqrt, internal::psqrt); + CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4); data1[0] = Scalar(nan, zero); data1[1] = Scalar(zero, nan); data1[2] = Scalar(nan, one); data1[3] = Scalar(one, nan); - CHECK_CWISE1(numext::sqrt, internal::psqrt); + CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4); data1[0] = Scalar(nan, nan); data1[1] = Scalar(inf, nan); data1[2] = Scalar(nan, inf); data1[3] = Scalar(-inf, nan); - CHECK_CWISE1(numext::sqrt, internal::psqrt); + CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4); } } diff --git a/test/packetmath_test_shared.h b/test/packetmath_test_shared.h index f8dc3711c..46a42604b 100644 --- a/test/packetmath_test_shared.h +++ b/test/packetmath_test_shared.h @@ -115,6 +115,17 @@ template<typename Scalar> bool areApprox(const Scalar* a, const Scalar* b, int s VERIFY(test::areApprox(ref, data2, PacketSize) && #POP); \ } +// Checks component-wise for input of size N. All of data1, data2, and ref +// should have size at least ceil(N/PacketSize)*PacketSize to avoid memory +// access errors. +#define CHECK_CWISE1_N(REFOP, POP, N) { \ + for (int i=0; i<N; ++i) \ + ref[i] = REFOP(data1[i]); \ + for (int j=0; j<N; j+=PacketSize) \ + internal::pstore(data2 + j, POP(internal::pload<Packet>(data1 + j))); \ + VERIFY(test::areApprox(ref, data2, N) && #POP); \ +} + template<bool Cond,typename Packet> struct packet_helper { |