aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/CUDA
diff options
context:
space:
mode:
authorGravatar Antonio Sanchez <cantonios@google.com>2020-12-22 22:49:06 -0800
committerGravatar Antonio Sanchez <cantonios@google.com>2020-12-22 23:25:23 -0800
commit070d303d56d46d2e018a58214da24ca629ea454f (patch)
tree3dfa72bf48ffdca0a67bd794596e4e452d50ed19 /Eigen/src/Core/arch/CUDA
parentfdf2ee62c5174441076fb64c9737d89bbe102759 (diff)
Add CUDA complex sqrt.
This is to support scalar `sqrt` of complex numbers `std::complex<T>` on device, as requested by the TensorFlow team. Technically, `std::complex` is not supported by NVCC on device (though it is by clang), so the default `sqrt(std::complex<T>)` function only works on the host. Here we create an overload to add back the functionality. Also modified the CMake file to add the `--expt-relaxed-constexpr` (or equivalent) flag for NVCC to allow calling constexpr functions from device functions, and added support for specifying the compute architecture for NVCC (this was already available for clang).
Diffstat (limited to 'Eigen/src/Core/arch/CUDA')
-rw-r--r--Eigen/src/Core/arch/CUDA/Complex.h55
1 files changed, 49 insertions, 6 deletions
diff --git a/Eigen/src/Core/arch/CUDA/Complex.h b/Eigen/src/Core/arch/CUDA/Complex.h
index 57d1201f4..69334cafe 100644
--- a/Eigen/src/Core/arch/CUDA/Complex.h
+++ b/Eigen/src/Core/arch/CUDA/Complex.h
@@ -12,12 +12,12 @@
// clang-format off
+#if defined(EIGEN_CUDACC) && defined(EIGEN_GPU_COMPILE_PHASE)
+
namespace Eigen {
namespace internal {
-#if defined(EIGEN_CUDACC) && defined(EIGEN_USE_GPU)
-
// Many std::complex methods such as operator+, operator-, operator* and
// operator/ are not constexpr. Due to this, clang does not treat them as device
// functions and thus Eigen functors making use of these operators fail to
@@ -94,10 +94,53 @@ template<typename T> struct scalar_quotient_op<const std::complex<T>, const std:
template<typename T> struct scalar_quotient_op<std::complex<T>, std::complex<T> > : scalar_quotient_op<const std::complex<T>, const std::complex<T> > {};
-#endif
+template<typename T>
+struct sqrt_impl<std::complex<T>> {
+ static EIGEN_DEVICE_FUNC std::complex<T> run(const std::complex<T>& z) {
+ // Computes the principal square root of the complex input z.
+ //
+ // For the complex square root of the number x + i*y, we want to find real
+ // numbers u and v such that
+ // (u + i*v)^2 = x + i*y <=>
+ // u^2 - v^2 + i*2*u*v = x + i*y.
+ // By equating the real and imaginary parts we get:
+ // u^2 - v^2 = x
+ // 2*u*v = y.
+ //
+ // For x >= 0, this has the numerically stable solution
+ // u = sqrt(0.5 * (x + sqrt(x^2 + y^2)))
+ // v = y / (2 * u)
+ // and for x < 0,
+ // v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2)))
+ // u = y / (2 * v)
+ //
+ // Letting w = sqrt(0.5 * (|x| + |z|)), where |z| = sqrt(x^2 + y^2),
+ // if x == 0: u = w, v = sign(y) * w
+ // if x > 0: u = w, v = y / (2 * w)
+ // if x < 0: u = |y| / (2 * w), v = sign(y) * w
+
+ const T x = numext::real(z);
+ const T y = numext::imag(z);
+ const T zero = T(0);
+ const T cst_half = T(0.5);
+
+ // Special case of isinf(y): the result is (+inf, y), whatever x is.
+ if ((numext::isinf)(y)) {
+ const T inf = std::numeric_limits<T>::infinity();
+ return std::complex<T>(inf, y);
+ }
+
+ T w = numext::sqrt(cst_half * (numext::abs(x) + numext::abs(z)));
+ return
+ x == zero ? std::complex<T>(w, y < zero ? -w : w)
+ : x > zero ? std::complex<T>(w, y / (2 * w))
+ : std::complex<T>(numext::abs(y) / (2 * w), y < zero ? -w : w );
+ }
+};
-} // end namespace internal
+} // namespace internal
+} // namespace Eigen
-} // end namespace Eigen
+#endif
-#endif // EIGEN_COMPLEX_CUDA_H
+#endif // EIGEN_COMPLEX_CUDA_H