diff options
author | RJ Ryan <rjryan@google.com> | 2016-09-20 07:18:20 -0700 |
---|---|---|
committer | RJ Ryan <rjryan@google.com> | 2016-09-20 07:18:20 -0700 |
commit | b2c6dc48d9189eb96f878aa6028aec245eadde85 (patch) | |
tree | d50f0abc9a8873616bea6c0a8a62c4a07fae7c10 /Eigen/src | |
parent | 8a66ca4b100577e5a38082d47a1ffc0183574046 (diff) |
Add CUDA-specific std::complex<T> specializations for scalar_sum_op, scalar_difference_op, scalar_product_op, and scalar_quotient_op.
Diffstat (limited to 'Eigen/src')
-rw-r--r-- | Eigen/src/Core/arch/CUDA/Complex.h | 80 |
1 files changed, 80 insertions, 0 deletions
diff --git a/Eigen/src/Core/arch/CUDA/Complex.h b/Eigen/src/Core/arch/CUDA/Complex.h new file mode 100644 index 000000000..aa511a4b2 --- /dev/null +++ b/Eigen/src/Core/arch/CUDA/Complex.h @@ -0,0 +1,80 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_COMPLEX_CUDA_H +#define EIGEN_COMPLEX_CUDA_H + +// clang-format off + +namespace Eigen { + +namespace internal { + +#if defined(__CUDACC__) && defined(EIGEN_USE_GPU) + +// Many std::complex methods such as operator+, operator-, operator* and +// operator/ are not constexpr. Due to this, clang does not treat them as device +// functions and thus Eigen functors making use of these operators fail to +// compile. Here, we manually specialize these functors for complex types when +// building for CUDA to avoid non-constexpr methods. + +template<typename T> struct scalar_sum_op<std::complex<T>> { + EIGEN_EMPTY_STRUCT_CTOR(scalar_sum_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const { + return std::complex<T>(numext::real(a) + numext::real(b), + numext::imag(a) + numext::imag(b)); + } +}; + +template<typename T> struct scalar_difference_op<std::complex<T>> { + EIGEN_EMPTY_STRUCT_CTOR(scalar_difference_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const { + return std::complex<T>(numext::real(a) - numext::real(b), + numext::imag(a) - numext::imag(b)); + } +}; + +template<typename T> struct scalar_product_op<std::complex<T>, std::complex<T>> { + enum { + Vectorizable = packet_traits<std::complex<T>>::HasMul + }; + EIGEN_EMPTY_STRUCT_CTOR(scalar_product_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const { + const T a_real = numext::real(a); + const T a_imag = numext::imag(a); + const T b_real = numext::real(b); + const T b_imag = numext::imag(b); + return std::complex<T>(a_real * b_real - a_imag * b_imag, + a_real * b_imag + a_imag * b_real); + } +}; + +template<typename T> struct scalar_quotient_op<std::complex<T>, std::complex<T>> { + enum { + Vectorizable = packet_traits<std::complex<T>>::HasDiv + }; + EIGEN_EMPTY_STRUCT_CTOR(scalar_quotient_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const { + const T a_real = numext::real(a); + const T a_imag = numext::imag(a); + const T b_real = numext::real(b); + const T b_imag = numext::imag(b); + const T norm = T(1) / (b_real * b_real + b_imag * b_imag); + return std::complex<T>((a_real * b_real + a_imag * b_imag) * norm, + (a_imag * b_real - a_real * b_imag) * norm); + } +}; + +#endif + +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_COMPLEX_CUDA_H |