From 409e887d785012afdc4a4e661b9b78e8990e2623 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Mon, 3 Oct 2016 11:06:24 -0700 Subject: Added support for constand std::complex numbers on GPU --- Eigen/src/Core/arch/CUDA/Complex.h | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) (limited to 'Eigen/src/Core/arch/CUDA') diff --git a/Eigen/src/Core/arch/CUDA/Complex.h b/Eigen/src/Core/arch/CUDA/Complex.h index f133b2db9..9c2536509 100644 --- a/Eigen/src/Core/arch/CUDA/Complex.h +++ b/Eigen/src/Core/arch/CUDA/Complex.h @@ -24,34 +24,43 @@ namespace internal { // compile. Here, we manually specialize these functors for complex types when // building for CUDA to avoid non-constexpr methods. -template struct scalar_sum_op> { +// Sum +template struct scalar_sum_op, const std::complex > : binary_op_base, const std::complex > { typedef typename std::complex result_type; EIGEN_EMPTY_STRUCT_CTOR(scalar_sum_op) - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex operator() (const std::complex& a, const std::complex& b) const { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex operator() (const std::complex& a, const std::complex& b) const { return std::complex(numext::real(a) + numext::real(b), numext::imag(a) + numext::imag(b)); } }; -template struct scalar_difference_op> { +template struct scalar_sum_op, std::complex > : scalar_sum_op, const std::complex > {}; + + +// Difference +template struct scalar_difference_op, const std::complex > : binary_op_base, const std::complex > { typedef typename std::complex result_type; EIGEN_EMPTY_STRUCT_CTOR(scalar_difference_op) - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex operator() (const std::complex& a, const std::complex& b) const { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex operator() (const std::complex& a, const std::complex& b) const { return std::complex(numext::real(a) - numext::real(b), numext::imag(a) - numext::imag(b)); } }; -template struct scalar_product_op, std::complex> { +template struct scalar_difference_op, std::complex > : scalar_difference_op, const std::complex > {}; + + +// Product +template struct scalar_product_op, const std::complex > : binary_op_base, const std::complex > { enum { Vectorizable = packet_traits>::HasMul }; typedef typename std::complex result_type; EIGEN_EMPTY_STRUCT_CTOR(scalar_product_op) - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex operator() (const std::complex& a, const std::complex& b) const { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex operator() (const std::complex& a, const std::complex& b) const { const T a_real = numext::real(a); const T a_imag = numext::imag(a); const T b_real = numext::real(b); @@ -61,14 +70,18 @@ template struct scalar_product_op, std::complex> } }; -template struct scalar_quotient_op, std::complex> { +template struct scalar_product_op, std::complex > : scalar_product_op, const std::complex > {}; + + +// Quotient +template struct scalar_quotient_op, const std::complex > : binary_op_base, const std::complex > { enum { Vectorizable = packet_traits>::HasDiv }; typedef typename std::complex result_type; EIGEN_EMPTY_STRUCT_CTOR(scalar_quotient_op) - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex operator() (const std::complex& a, const std::complex& b) const { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex operator() (const std::complex& a, const std::complex& b) const { const T a_real = numext::real(a); const T a_imag = numext::imag(a); const T b_real = numext::real(b); @@ -79,6 +92,8 @@ template struct scalar_quotient_op, std::complex> } }; +template struct scalar_quotient_op, std::complex > : scalar_quotient_op, const std::complex > {}; + #endif } // end namespace internal -- cgit v1.2.3