aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/CUDA
diff options
context:
space:
mode:
authorGravatar RJ Ryan <rjryan@google.com>2016-09-20 07:18:20 -0700
committerGravatar RJ Ryan <rjryan@google.com>2016-09-20 07:18:20 -0700
commitb2c6dc48d9189eb96f878aa6028aec245eadde85 (patch)
treed50f0abc9a8873616bea6c0a8a62c4a07fae7c10 /Eigen/src/Core/arch/CUDA
parent8a66ca4b100577e5a38082d47a1ffc0183574046 (diff)
Add CUDA-specific std::complex<T> specializations for scalar_sum_op, scalar_difference_op, scalar_product_op, and scalar_quotient_op.
Diffstat (limited to 'Eigen/src/Core/arch/CUDA')
-rw-r--r--Eigen/src/Core/arch/CUDA/Complex.h80
1 files changed, 80 insertions, 0 deletions
diff --git a/Eigen/src/Core/arch/CUDA/Complex.h b/Eigen/src/Core/arch/CUDA/Complex.h
new file mode 100644
index 000000000..aa511a4b2
--- /dev/null
+++ b/Eigen/src/Core/arch/CUDA/Complex.h
@@ -0,0 +1,80 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_COMPLEX_CUDA_H
+#define EIGEN_COMPLEX_CUDA_H
+
+// clang-format off
+
+namespace Eigen {
+
+namespace internal {
+
+#if defined(__CUDACC__) && defined(EIGEN_USE_GPU)
+
+// Many std::complex methods such as operator+, operator-, operator* and
+// operator/ are not constexpr. Due to this, clang does not treat them as device
+// functions and thus Eigen functors making use of these operators fail to
+// compile. Here, we manually specialize these functors for complex types when
+// building for CUDA to avoid non-constexpr methods.
+
+template<typename T> struct scalar_sum_op<std::complex<T>> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_sum_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
+ return std::complex<T>(numext::real(a) + numext::real(b),
+ numext::imag(a) + numext::imag(b));
+ }
+};
+
+template<typename T> struct scalar_difference_op<std::complex<T>> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_difference_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
+ return std::complex<T>(numext::real(a) - numext::real(b),
+ numext::imag(a) - numext::imag(b));
+ }
+};
+
+template<typename T> struct scalar_product_op<std::complex<T>, std::complex<T>> {
+ enum {
+ Vectorizable = packet_traits<std::complex<T>>::HasMul
+ };
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_product_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
+ const T a_real = numext::real(a);
+ const T a_imag = numext::imag(a);
+ const T b_real = numext::real(b);
+ const T b_imag = numext::imag(b);
+ return std::complex<T>(a_real * b_real - a_imag * b_imag,
+ a_real * b_imag + a_imag * b_real);
+ }
+};
+
+template<typename T> struct scalar_quotient_op<std::complex<T>, std::complex<T>> {
+ enum {
+ Vectorizable = packet_traits<std::complex<T>>::HasDiv
+ };
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_quotient_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
+ const T a_real = numext::real(a);
+ const T a_imag = numext::imag(a);
+ const T b_real = numext::real(b);
+ const T b_imag = numext::imag(b);
+ const T norm = T(1) / (b_real * b_real + b_imag * b_imag);
+ return std::complex<T>((a_real * b_real + a_imag * b_imag) * norm,
+ (a_imag * b_real - a_real * b_imag) * norm);
+ }
+};
+
+#endif
+
+} // end namespace internal
+
+} // end namespace Eigen
+
+#endif // EIGEN_COMPLEX_CUDA_H