diff options
author | Antonio Sanchez <cantonios@google.com> | 2021-01-07 09:39:05 -0800 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2021-01-08 01:17:19 +0000 |
commit | f149e0ebc3d3d5ca63234e58ca72690caf07e3b5 (patch) | |
tree | 8c5431fd057c96b8231be84b2908d130b49d61ec /Eigen/src/Core/arch/CUDA/Complex.h | |
parent | 8d9cfba799ce3462c12568a36392e0abf36fc62d (diff) |
Fix MSVC complex sqrt and packetmath test.
MSVC incorrectly handles `inf` cases for `std::sqrt<std::complex<T>>`.
Here we replace it with a custom version (currently used on GPU).
Also fixed the `packetmath` test, which previously skipped several
corner cases since `CHECK_CWISE1` only tests the first `PacketSize`
elements.
Diffstat (limited to 'Eigen/src/Core/arch/CUDA/Complex.h')
-rw-r--r-- | Eigen/src/Core/arch/CUDA/Complex.h | 46 |
1 files changed, 6 insertions, 40 deletions
diff --git a/Eigen/src/Core/arch/CUDA/Complex.h b/Eigen/src/Core/arch/CUDA/Complex.h index 69334cafe..ab0207cac 100644 --- a/Eigen/src/Core/arch/CUDA/Complex.h +++ b/Eigen/src/Core/arch/CUDA/Complex.h @@ -95,46 +95,12 @@ template<typename T> struct scalar_quotient_op<const std::complex<T>, const std: template<typename T> struct scalar_quotient_op<std::complex<T>, std::complex<T> > : scalar_quotient_op<const std::complex<T>, const std::complex<T> > {}; template<typename T> -struct sqrt_impl<std::complex<T>> { - static EIGEN_DEVICE_FUNC std::complex<T> run(const std::complex<T>& z) { - // Computes the principal sqrt of the input. - // - // For a complex square root of the number x + i*y. We want to find real - // numbers u and v such that - // (u + i*v)^2 = x + i*y <=> - // u^2 - v^2 + i*2*u*v = x + i*v. - // By equating the real and imaginary parts we get: - // u^2 - v^2 = x - // 2*u*v = y. - // - // For x >= 0, this has the numerically stable solution - // u = sqrt(0.5 * (x + sqrt(x^2 + y^2))) - // v = y / (2 * u) - // and for x < 0, - // v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2))) - // u = y / (2 * v) - // - // Letting w = sqrt(0.5 * (|x| + |z|)), - // if x == 0: u = w, v = sign(y) * w - // if x > 0: u = w, v = y / (2 * w) - // if x < 0: u = |y| / (2 * w), v = sign(y) * w - - const T x = numext::real(z); - const T y = numext::imag(z); - const T zero = T(0); - const T cst_half = T(0.5); - - // Special case of isinf(y) - if ((numext::isinf)(y)) { - const T inf = std::numeric_limits<T>::infinity(); - return std::complex<T>(inf, y); - } - - T w = numext::sqrt(cst_half * (numext::abs(x) + numext::abs(z))); - return - x == zero ? std::complex<T>(w, y < zero ? -w : w) - : x > zero ? std::complex<T>(w, y / (2 * w)) - : std::complex<T>(numext::abs(y) / (2 * w), y < zero ? -w : w ); +struct sqrt_impl<std::complex<T> > +{ + EIGEN_DEVICE_FUNC + static EIGEN_ALWAYS_INLINE std::complex<T> run(const std::complex<T>& x) + { + return complex_sqrt<T>(x); } }; |