diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-10-03 11:06:24 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-10-03 11:06:24 -0700 |
commit | 409e887d785012afdc4a4e661b9b78e8990e2623 (patch) | |
tree | eacfc43a9d4dc08e5a878f6ae472e11d3ece12ca /Eigen/src/Core/arch/CUDA/Complex.h | |
parent | 9d6d0dff8f0c1e8630996c3a4867ff0599566b33 (diff) |
Added support for constand std::complex numbers on GPU
Diffstat (limited to 'Eigen/src/Core/arch/CUDA/Complex.h')
-rw-r--r-- | Eigen/src/Core/arch/CUDA/Complex.h | 31 |
1 files changed, 23 insertions, 8 deletions
diff --git a/Eigen/src/Core/arch/CUDA/Complex.h b/Eigen/src/Core/arch/CUDA/Complex.h index f133b2db9..9c2536509 100644 --- a/Eigen/src/Core/arch/CUDA/Complex.h +++ b/Eigen/src/Core/arch/CUDA/Complex.h @@ -24,34 +24,43 @@ namespace internal { // compile. Here, we manually specialize these functors for complex types when // building for CUDA to avoid non-constexpr methods. -template<typename T> struct scalar_sum_op<std::complex<T>> { +// Sum +template<typename T> struct scalar_sum_op<const std::complex<T>, const std::complex<T> > : binary_op_base<const std::complex<T>, const std::complex<T> > { typedef typename std::complex<T> result_type; EIGEN_EMPTY_STRUCT_CTOR(scalar_sum_op) - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const { return std::complex<T>(numext::real(a) + numext::real(b), numext::imag(a) + numext::imag(b)); } }; -template<typename T> struct scalar_difference_op<std::complex<T>> { +template<typename T> struct scalar_sum_op<std::complex<T>, std::complex<T> > : scalar_sum_op<const std::complex<T>, const std::complex<T> > {}; + + +// Difference +template<typename T> struct scalar_difference_op<const std::complex<T>, const std::complex<T> > : binary_op_base<const std::complex<T>, const std::complex<T> > { typedef typename std::complex<T> result_type; EIGEN_EMPTY_STRUCT_CTOR(scalar_difference_op) - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const { return std::complex<T>(numext::real(a) - numext::real(b), numext::imag(a) - numext::imag(b)); } }; -template<typename T> struct scalar_product_op<std::complex<T>, std::complex<T>> { +template<typename T> struct scalar_difference_op<std::complex<T>, std::complex<T> > : scalar_difference_op<const std::complex<T>, const std::complex<T> > {}; + + +// Product +template<typename T> struct scalar_product_op<const std::complex<T>, const std::complex<T> > : binary_op_base<const std::complex<T>, const std::complex<T> > { enum { Vectorizable = packet_traits<std::complex<T>>::HasMul }; typedef typename std::complex<T> result_type; EIGEN_EMPTY_STRUCT_CTOR(scalar_product_op) - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const { const T a_real = numext::real(a); const T a_imag = numext::imag(a); const T b_real = numext::real(b); @@ -61,14 +70,18 @@ template<typename T> struct scalar_product_op<std::complex<T>, std::complex<T>> } }; -template<typename T> struct scalar_quotient_op<std::complex<T>, std::complex<T>> { +template<typename T> struct scalar_product_op<std::complex<T>, std::complex<T> > : scalar_product_op<const std::complex<T>, const std::complex<T> > {}; + + +// Quotient +template<typename T> struct scalar_quotient_op<const std::complex<T>, const std::complex<T> > : binary_op_base<const std::complex<T>, const std::complex<T> > { enum { Vectorizable = packet_traits<std::complex<T>>::HasDiv }; typedef typename std::complex<T> result_type; EIGEN_EMPTY_STRUCT_CTOR(scalar_quotient_op) - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const { const T a_real = numext::real(a); const T a_imag = numext::imag(a); const T b_real = numext::real(b); @@ -79,6 +92,8 @@ template<typename T> struct scalar_quotient_op<std::complex<T>, std::complex<T>> } }; +template<typename T> struct scalar_quotient_op<std::complex<T>, std::complex<T> > : scalar_quotient_op<const std::complex<T>, const std::complex<T> > {}; + #endif } // end namespace internal |