From 070d303d56d46d2e018a58214da24ca629ea454f Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Tue, 22 Dec 2020 22:49:06 -0800 Subject: Add CUDA complex sqrt. This is to support scalar `sqrt` of complex numbers `std::complex` on device, requested by Tensorflow folks. Technically `std::complex` is not supported by NVCC on device (though it is by clang), so the default `sqrt(std::complex)` function only works on the host. Here we create an overload to add back the functionality. Also modified the CMake file to add `--relaxed-constexpr` (or equivalent) flag for NVCC to allow calling constexpr functions from device functions, and added support for specifying compute architecture for NVCC (was already available for clang). --- Eigen/src/Core/arch/CUDA/Complex.h | 55 +++++++++++++++++++++++++++++++++----- 1 file changed, 49 insertions(+), 6 deletions(-) (limited to 'Eigen/src/Core/arch/CUDA') diff --git a/Eigen/src/Core/arch/CUDA/Complex.h b/Eigen/src/Core/arch/CUDA/Complex.h index 57d1201f4..69334cafe 100644 --- a/Eigen/src/Core/arch/CUDA/Complex.h +++ b/Eigen/src/Core/arch/CUDA/Complex.h @@ -12,12 +12,12 @@ // clang-format off +#if defined(EIGEN_CUDACC) && defined(EIGEN_GPU_COMPILE_PHASE) + namespace Eigen { namespace internal { -#if defined(EIGEN_CUDACC) && defined(EIGEN_USE_GPU) - // Many std::complex methods such as operator+, operator-, operator* and // operator/ are not constexpr. Due to this, clang does not treat them as device // functions and thus Eigen functors making use of these operators fail to @@ -94,10 +94,53 @@ template struct scalar_quotient_op, const std: template struct scalar_quotient_op, std::complex > : scalar_quotient_op, const std::complex > {}; -#endif +template +struct sqrt_impl> { + static EIGEN_DEVICE_FUNC std::complex run(const std::complex& z) { + // Computes the principal sqrt of the input. + // + // For a complex square root of the number x + i*y. We want to find real + // numbers u and v such that + // (u + i*v)^2 = x + i*y <=> + // u^2 - v^2 + i*2*u*v = x + i*v. + // By equating the real and imaginary parts we get: + // u^2 - v^2 = x + // 2*u*v = y. + // + // For x >= 0, this has the numerically stable solution + // u = sqrt(0.5 * (x + sqrt(x^2 + y^2))) + // v = y / (2 * u) + // and for x < 0, + // v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2))) + // u = y / (2 * v) + // + // Letting w = sqrt(0.5 * (|x| + |z|)), + // if x == 0: u = w, v = sign(y) * w + // if x > 0: u = w, v = y / (2 * w) + // if x < 0: u = |y| / (2 * w), v = sign(y) * w + + const T x = numext::real(z); + const T y = numext::imag(z); + const T zero = T(0); + const T cst_half = T(0.5); + + // Special case of isinf(y) + if ((numext::isinf)(y)) { + const T inf = std::numeric_limits::infinity(); + return std::complex(inf, y); + } + + T w = numext::sqrt(cst_half * (numext::abs(x) + numext::abs(z))); + return + x == zero ? std::complex(w, y < zero ? -w : w) + : x > zero ? std::complex(w, y / (2 * w)) + : std::complex(numext::abs(y) / (2 * w), y < zero ? -w : w ); + } +}; -} // end namespace internal +} // namespace internal +} // namespace Eigen -} // end namespace Eigen +#endif -#endif // EIGEN_COMPLEX_CUDA_H +#endif // EIGEN_COMPLEX_CUDA_H -- cgit v1.2.3