aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen
diff options
context:
space:
mode:
authorGravatar Antonio Sanchez <cantonios@google.com>2021-04-23 16:04:01 -0700
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2021-04-29 17:39:58 +0000
commit1c013be2cc6a999268be2f25575cd6a07bd52c45 (patch)
tree12f67dcbf72e6b747900d64fea7c3e572da3917f /Eigen
parent172db7bfc32def5ed0f885287e352b63dd5cd767 (diff)
Better CUDA complex division.
The original produced NaNs when dividing 0/b for subnormal b. The `complex_divide_stable` was changed to use the more common Smith's algorithm.
Diffstat (limited to 'Eigen')
-rw-r--r--Eigen/src/Core/arch/CUDA/Complex.h27
1 files changed, 13 insertions, 14 deletions
diff --git a/Eigen/src/Core/arch/CUDA/Complex.h b/Eigen/src/Core/arch/CUDA/Complex.h
index b1618e567..deb4c8694 100644
--- a/Eigen/src/Core/arch/CUDA/Complex.h
+++ b/Eigen/src/Core/arch/CUDA/Complex.h
@@ -67,27 +67,26 @@ std::complex<T> complex_divide_fast(const std::complex<T>& a, const std::complex
const T a_imag = numext::imag(a);
const T b_real = numext::real(b);
const T b_imag = numext::imag(b);
- const T norm = T(1) / (b_real * b_real + b_imag * b_imag);
- return std::complex<T>((a_real * b_real + a_imag * b_imag) * norm,
- (a_imag * b_real - a_real * b_imag) * norm);
+ const T norm = (b_real * b_real + b_imag * b_imag);
+ return std::complex<T>((a_real * b_real + a_imag * b_imag) / norm,
+ (a_imag * b_real - a_real * b_imag) / norm);
}
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
std::complex<T> complex_divide_stable(const std::complex<T>& a, const std::complex<T>& b) {
+ const T a_real = numext::real(a);
+ const T a_imag = numext::imag(a);
const T b_real = numext::real(b);
const T b_imag = numext::imag(b);
- // Guard against over/under-flow.
- const T scale = T(1) / (numext::abs(b_real) + numext::abs(b_imag));
- const T a_real_scaled = numext::real(a) * scale;
- const T a_imag_scaled = numext::imag(a) * scale;
- const T b_real_scaled = b_real * scale;
- const T b_imag_scaled = b_imag * scale;
-
- const T b_norm2_scaled = b_real_scaled * b_real_scaled + b_imag_scaled * b_imag_scaled;
- return std::complex<T>(
- (a_real_scaled * b_real_scaled + a_imag_scaled * b_imag_scaled) / b_norm2_scaled,
- (a_imag_scaled * b_real_scaled - a_real_scaled * b_imag_scaled) / b_norm2_scaled);
+ // Smith's complex division (https://arxiv.org/pdf/1210.4539.pdf),
+ // guards against over/under-flow.
+ const bool scale_imag = numext::abs(b_imag) <= numext::abs(b_real);
+ const T rscale = scale_imag ? T(1) : b_real / b_imag;
+ const T iscale = scale_imag ? b_imag / b_real : T(1);
+ const T denominator = b_real * rscale + b_imag * iscale;
+ return std::complex<T>((a_real * rscale + a_imag * iscale) / denominator,
+ (a_imag * rscale - a_real * iscale) / denominator);
}
template<typename T>