From 070d303d56d46d2e018a58214da24ca629ea454f Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Tue, 22 Dec 2020 22:49:06 -0800 Subject: Add CUDA complex sqrt. This is to support scalar `sqrt` of complex numbers `std::complex` on device, requested by Tensorflow folks. Technically `std::complex` is not supported by NVCC on device (though it is by clang), so the default `sqrt(std::complex)` function only works on the host. Here we create an overload to add back the functionality. Also modified the CMake file to add `--relaxed-constexpr` (or equivalent) flag for NVCC to allow calling constexpr functions from device functions, and added support for specifying compute architecture for NVCC (was already available for clang). --- Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h') diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index a6d2de62b..9253d8cab 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -703,8 +703,8 @@ Packet psqrt_complex(const Packet& a) { // u = sqrt(0.5 * (x + sqrt(x^2 + y^2))) // v = 0.5 * (y / u) // and for x < 0, - // v = sign(y) * sqrt(0.5 * (x + sqrt(x^2 + y^2))) - // u = |0.5 * (y / v)| + // v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2))) + // u = 0.5 * (y / v) // // To avoid unnecessary over- and underflow, we compute sqrt(x^2 + y^2) as // l = max(|x|, |y|) * sqrt(1 + (min(|x|, |y|) / max(|x|, |y|))^2) , -- cgit v1.2.3