aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/CUDA
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-10-03 11:06:24 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-10-03 11:06:24 -0700
commit409e887d785012afdc4a4e661b9b78e8990e2623 (patch)
treeeacfc43a9d4dc08e5a878f6ae472e11d3ece12ca /Eigen/src/Core/arch/CUDA
parent9d6d0dff8f0c1e8630996c3a4867ff0599566b33 (diff)
Added support for constand std::complex numbers on GPU
Diffstat (limited to 'Eigen/src/Core/arch/CUDA')
-rw-r--r--Eigen/src/Core/arch/CUDA/Complex.h31
1 files changed, 23 insertions, 8 deletions
diff --git a/Eigen/src/Core/arch/CUDA/Complex.h b/Eigen/src/Core/arch/CUDA/Complex.h
index f133b2db9..9c2536509 100644
--- a/Eigen/src/Core/arch/CUDA/Complex.h
+++ b/Eigen/src/Core/arch/CUDA/Complex.h
@@ -24,34 +24,43 @@ namespace internal {
// compile. Here, we manually specialize these functors for complex types when
// building for CUDA to avoid non-constexpr methods.
-template<typename T> struct scalar_sum_op<std::complex<T>> {
+// Sum
+template<typename T> struct scalar_sum_op<const std::complex<T>, const std::complex<T> > : binary_op_base<const std::complex<T>, const std::complex<T> > {
typedef typename std::complex<T> result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_sum_op)
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
return std::complex<T>(numext::real(a) + numext::real(b),
numext::imag(a) + numext::imag(b));
}
};
-template<typename T> struct scalar_difference_op<std::complex<T>> {
+template<typename T> struct scalar_sum_op<std::complex<T>, std::complex<T> > : scalar_sum_op<const std::complex<T>, const std::complex<T> > {};
+
+
+// Difference
+template<typename T> struct scalar_difference_op<const std::complex<T>, const std::complex<T> > : binary_op_base<const std::complex<T>, const std::complex<T> > {
typedef typename std::complex<T> result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_difference_op)
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
return std::complex<T>(numext::real(a) - numext::real(b),
numext::imag(a) - numext::imag(b));
}
};
-template<typename T> struct scalar_product_op<std::complex<T>, std::complex<T>> {
+template<typename T> struct scalar_difference_op<std::complex<T>, std::complex<T> > : scalar_difference_op<const std::complex<T>, const std::complex<T> > {};
+
+
+// Product
+template<typename T> struct scalar_product_op<const std::complex<T>, const std::complex<T> > : binary_op_base<const std::complex<T>, const std::complex<T> > {
enum {
Vectorizable = packet_traits<std::complex<T>>::HasMul
};
typedef typename std::complex<T> result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_product_op)
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
const T a_real = numext::real(a);
const T a_imag = numext::imag(a);
const T b_real = numext::real(b);
@@ -61,14 +70,18 @@ template<typename T> struct scalar_product_op<std::complex<T>, std::complex<T>>
}
};
-template<typename T> struct scalar_quotient_op<std::complex<T>, std::complex<T>> {
+template<typename T> struct scalar_product_op<std::complex<T>, std::complex<T> > : scalar_product_op<const std::complex<T>, const std::complex<T> > {};
+
+
+// Quotient
+template<typename T> struct scalar_quotient_op<const std::complex<T>, const std::complex<T> > : binary_op_base<const std::complex<T>, const std::complex<T> > {
enum {
Vectorizable = packet_traits<std::complex<T>>::HasDiv
};
typedef typename std::complex<T> result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_quotient_op)
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
const T a_real = numext::real(a);
const T a_imag = numext::imag(a);
const T b_real = numext::real(b);
@@ -79,6 +92,8 @@ template<typename T> struct scalar_quotient_op<std::complex<T>, std::complex<T>>
}
};
+template<typename T> struct scalar_quotient_op<std::complex<T>, std::complex<T> > : scalar_quotient_op<const std::complex<T>, const std::complex<T> > {};
+
#endif
} // end namespace internal