diff options
author | Antonio Sanchez <cantonios@google.com> | 2020-12-22 22:49:06 -0800 |
---|---|---|
committer | Antonio Sanchez <cantonios@google.com> | 2020-12-22 23:25:23 -0800 |
commit | 070d303d56d46d2e018a58214da24ca629ea454f (patch) | |
tree | 3dfa72bf48ffdca0a67bd794596e4e452d50ed19 /Eigen/src/Core/arch/CUDA | |
parent | fdf2ee62c5174441076fb64c9737d89bbe102759 (diff) |
Add CUDA complex sqrt.
This is to support scalar `sqrt` of complex numbers `std::complex<T>` on
device, requested by Tensorflow folks.
Technically `std::complex` is not supported by NVCC on device
(though it is by clang), so the default `sqrt(std::complex<T>)` function only
works on the host. Here we create an overload to add back the
functionality.
Also modified the CMake file to add `--relaxed-constexpr` (or
equivalent) flag for NVCC to allow calling constexpr functions from
device functions, and added support for specifying compute architecture for
NVCC (was already available for clang).
Diffstat (limited to 'Eigen/src/Core/arch/CUDA')
-rw-r--r-- | Eigen/src/Core/arch/CUDA/Complex.h | 55 |
1 files changed, 49 insertions, 6 deletions
diff --git a/Eigen/src/Core/arch/CUDA/Complex.h b/Eigen/src/Core/arch/CUDA/Complex.h index 57d1201f4..69334cafe 100644 --- a/Eigen/src/Core/arch/CUDA/Complex.h +++ b/Eigen/src/Core/arch/CUDA/Complex.h @@ -12,12 +12,12 @@ // clang-format off +#if defined(EIGEN_CUDACC) && defined(EIGEN_GPU_COMPILE_PHASE) + namespace Eigen { namespace internal { -#if defined(EIGEN_CUDACC) && defined(EIGEN_USE_GPU) - // Many std::complex methods such as operator+, operator-, operator* and // operator/ are not constexpr. Due to this, clang does not treat them as device // functions and thus Eigen functors making use of these operators fail to @@ -94,10 +94,53 @@ template<typename T> struct scalar_quotient_op<const std::complex<T>, const std: template<typename T> struct scalar_quotient_op<std::complex<T>, std::complex<T> > : scalar_quotient_op<const std::complex<T>, const std::complex<T> > {}; -#endif +template<typename T> +struct sqrt_impl<std::complex<T>> { + static EIGEN_DEVICE_FUNC std::complex<T> run(const std::complex<T>& z) { + // Computes the principal sqrt of the input. + // + // For a complex square root of the number x + i*y. We want to find real + // numbers u and v such that + // (u + i*v)^2 = x + i*y <=> + // u^2 - v^2 + i*2*u*v = x + i*v. + // By equating the real and imaginary parts we get: + // u^2 - v^2 = x + // 2*u*v = y. + // + // For x >= 0, this has the numerically stable solution + // u = sqrt(0.5 * (x + sqrt(x^2 + y^2))) + // v = y / (2 * u) + // and for x < 0, + // v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2))) + // u = y / (2 * v) + // + // Letting w = sqrt(0.5 * (|x| + |z|)), + // if x == 0: u = w, v = sign(y) * w + // if x > 0: u = w, v = y / (2 * w) + // if x < 0: u = |y| / (2 * w), v = sign(y) * w + + const T x = numext::real(z); + const T y = numext::imag(z); + const T zero = T(0); + const T cst_half = T(0.5); + + // Special case of isinf(y) + if ((numext::isinf)(y)) { + const T inf = std::numeric_limits<T>::infinity(); + return std::complex<T>(inf, y); + } + + T w = numext::sqrt(cst_half * (numext::abs(x) + numext::abs(z))); + return + x == zero ? std::complex<T>(w, y < zero ? -w : w) + : x > zero ? std::complex<T>(w, y / (2 * w)) + : std::complex<T>(numext::abs(y) / (2 * w), y < zero ? -w : w ); + } +}; -} // end namespace internal +} // namespace internal +} // namespace Eigen -} // end namespace Eigen +#endif -#endif // EIGEN_COMPLEX_CUDA_H +#endif // EIGEN_COMPLEX_CUDA_H |