aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src
diff options
context:
space:
mode:
authorGravatar Antonio Sanchez <cantonios@google.com>2021-01-07 09:39:05 -0800
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2021-01-08 01:17:19 +0000
commitf149e0ebc3d3d5ca63234e58ca72690caf07e3b5 (patch)
tree8c5431fd057c96b8231be84b2908d130b49d61ec /Eigen/src
parent8d9cfba799ce3462c12568a36392e0abf36fc62d (diff)
Fix MSVC complex sqrt and packetmath test.
MSVC incorrectly handles `inf` cases for `std::sqrt<std::complex<T>>`. Here we replace it with a custom version (currently used on GPU). Also fixed the `packetmath` test, which previously skipped several corner cases since `CHECK_CWISE1` only tests the first `PacketSize` elements.
Diffstat (limited to 'Eigen/src')
-rw-r--r--Eigen/src/Core/MathFunctions.h16
-rw-r--r--Eigen/src/Core/MathFunctionsImpl.h44
-rw-r--r--Eigen/src/Core/arch/CUDA/Complex.h46
3 files changed, 66 insertions, 40 deletions
diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h
index 928bc8e72..5b5ca46f6 100644
--- a/Eigen/src/Core/MathFunctions.h
+++ b/Eigen/src/Core/MathFunctions.h
@@ -338,6 +338,22 @@ struct sqrt_impl
}
};
+// Complex sqrt defined in MathFunctionsImpl.h.
+template<typename T> std::complex<T> complex_sqrt(const std::complex<T>& a_x);
+
+// MSVC incorrectly handles inf cases.
+#if EIGEN_COMP_MSVC > 0
+template<typename T>
+struct sqrt_impl<std::complex<T> >
+{
+ EIGEN_DEVICE_FUNC
+ static EIGEN_ALWAYS_INLINE std::complex<T> run(const std::complex<T>& x)
+ {
+ return complex_sqrt<T>(x);
+ }
+};
+#endif
+
template<typename Scalar>
struct sqrt_retval
{
diff --git a/Eigen/src/Core/MathFunctionsImpl.h b/Eigen/src/Core/MathFunctionsImpl.h
index 8288ad834..8ecddebf6 100644
--- a/Eigen/src/Core/MathFunctionsImpl.h
+++ b/Eigen/src/Core/MathFunctionsImpl.h
@@ -99,6 +99,50 @@ struct hypot_impl
}
};
+// Generic complex sqrt implementation that correctly handles corner cases
+// according to https://en.cppreference.com/w/cpp/numeric/complex/sqrt
+template<typename T>
+EIGEN_DEVICE_FUNC std::complex<T> complex_sqrt(const std::complex<T>& z) {
+ // Computes the principal sqrt of the input.
+ //
+ // For a complex square root of the number x + i*y. We want to find real
+ // numbers u and v such that
+ // (u + i*v)^2 = x + i*y <=>
+ // u^2 - v^2 + i*2*u*v = x + i*v.
+ // By equating the real and imaginary parts we get:
+ // u^2 - v^2 = x
+ // 2*u*v = y.
+ //
+ // For x >= 0, this has the numerically stable solution
+ // u = sqrt(0.5 * (x + sqrt(x^2 + y^2)))
+ // v = y / (2 * u)
+ // and for x < 0,
+ // v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2)))
+ // u = y / (2 * v)
+ //
+ // Letting w = sqrt(0.5 * (|x| + |z|)),
+ // if x == 0: u = w, v = sign(y) * w
+ // if x > 0: u = w, v = y / (2 * w)
+ // if x < 0: u = |y| / (2 * w), v = sign(y) * w
+
+ const T x = numext::real(z);
+ const T y = numext::imag(z);
+ const T zero = T(0);
+ const T cst_half = T(0.5);
+
+ // Special case of isinf(y)
+ if ((numext::isinf)(y)) {
+ const T inf = std::numeric_limits<T>::infinity();
+ return std::complex<T>(inf, y);
+ }
+
+ T w = numext::sqrt(cst_half * (numext::abs(x) + numext::abs(z)));
+ return
+ x == zero ? std::complex<T>(w, y < zero ? -w : w)
+ : x > zero ? std::complex<T>(w, y / (2 * w))
+ : std::complex<T>(numext::abs(y) / (2 * w), y < zero ? -w : w );
+}
+
} // end namespace internal
} // end namespace Eigen
diff --git a/Eigen/src/Core/arch/CUDA/Complex.h b/Eigen/src/Core/arch/CUDA/Complex.h
index 69334cafe..ab0207cac 100644
--- a/Eigen/src/Core/arch/CUDA/Complex.h
+++ b/Eigen/src/Core/arch/CUDA/Complex.h
@@ -95,46 +95,12 @@ template<typename T> struct scalar_quotient_op<const std::complex<T>, const std:
template<typename T> struct scalar_quotient_op<std::complex<T>, std::complex<T> > : scalar_quotient_op<const std::complex<T>, const std::complex<T> > {};
template<typename T>
-struct sqrt_impl<std::complex<T>> {
- static EIGEN_DEVICE_FUNC std::complex<T> run(const std::complex<T>& z) {
- // Computes the principal sqrt of the input.
- //
- // For a complex square root of the number x + i*y. We want to find real
- // numbers u and v such that
- // (u + i*v)^2 = x + i*y <=>
- // u^2 - v^2 + i*2*u*v = x + i*v.
- // By equating the real and imaginary parts we get:
- // u^2 - v^2 = x
- // 2*u*v = y.
- //
- // For x >= 0, this has the numerically stable solution
- // u = sqrt(0.5 * (x + sqrt(x^2 + y^2)))
- // v = y / (2 * u)
- // and for x < 0,
- // v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2)))
- // u = y / (2 * v)
- //
- // Letting w = sqrt(0.5 * (|x| + |z|)),
- // if x == 0: u = w, v = sign(y) * w
- // if x > 0: u = w, v = y / (2 * w)
- // if x < 0: u = |y| / (2 * w), v = sign(y) * w
-
- const T x = numext::real(z);
- const T y = numext::imag(z);
- const T zero = T(0);
- const T cst_half = T(0.5);
-
- // Special case of isinf(y)
- if ((numext::isinf)(y)) {
- const T inf = std::numeric_limits<T>::infinity();
- return std::complex<T>(inf, y);
- }
-
- T w = numext::sqrt(cst_half * (numext::abs(x) + numext::abs(z)));
- return
- x == zero ? std::complex<T>(w, y < zero ? -w : w)
- : x > zero ? std::complex<T>(w, y / (2 * w))
- : std::complex<T>(numext::abs(y) / (2 * w), y < zero ? -w : w );
+struct sqrt_impl<std::complex<T> >
+{
+ EIGEN_DEVICE_FUNC
+ static EIGEN_ALWAYS_INLINE std::complex<T> run(const std::complex<T>& x)
+ {
+ return complex_sqrt<T>(x);
}
};