diff options
author | Rasmus Munk Larsen <rmlarsen@google.com> | 2021-07-02 13:36:05 -0700 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2021-07-02 20:42:15 +0000 |
commit | 7b35638ddb99a0298c5d3450de506a8e8e0203d3 (patch) | |
tree | e3c2b4c81330948aea0fe71816b739fef91fae6f /Eigen/src/Core/arch/Default | |
parent | aab747021be5ed1a1e9667243d884eb72003599d (diff) |
Fix breakage of conj_helper in conjunction with custom types introduced in !537.
Diffstat (limited to 'Eigen/src/Core/arch/Default')
-rw-r--r-- | Eigen/src/Core/arch/Default/ConjHelper.h | 68 |
1 files changed, 39 insertions, 29 deletions
diff --git a/Eigen/src/Core/arch/Default/ConjHelper.h b/Eigen/src/Core/arch/Default/ConjHelper.h index 255daddc5..53830b5a2 100644 --- a/Eigen/src/Core/arch/Default/ConjHelper.h +++ b/Eigen/src/Core/arch/Default/ConjHelper.h @@ -57,48 +57,58 @@ template<> struct conj_if<false> { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& pconj(const T& x) const { return x; } }; -// Generic implementation. +// Generic Implementation, assume scalars since the packet-version is +// specialized below. template<typename LhsType, typename RhsType, bool ConjLhs, bool ConjRhs> -struct conj_helper -{ - typedef typename ScalarBinaryOpTraits<LhsType,RhsType>::ReturnType ResultType; +struct conj_helper { + typedef typename ScalarBinaryOpTraits<LhsType, RhsType>::ReturnType ResultType; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType pmadd(const LhsType& x, const RhsType& y, const ResultType& c) const - { return Eigen::internal::pmadd(conj_if<ConjLhs>().pconj(x), conj_if<ConjRhs>().pconj(y), c); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType + pmadd(const LhsType& x, const RhsType& y, const ResultType& c) const + { return this->pmul(x, y) + c; } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType pmul(const LhsType& x, const RhsType& y) const - { return Eigen::internal::pmul(conj_if<ConjLhs>().pconj(x), conj_if<ConjRhs>().pconj(y)); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType + pmul(const LhsType& x, const RhsType& y) const + { return conj_if<ConjLhs>()(x) * conj_if<ConjRhs>()(y); } }; -template<typename LhsType, typename RhsType> -struct conj_helper<LhsType, RhsType, true, true> -{ - typedef typename ScalarBinaryOpTraits<LhsType,RhsType>::ReturnType ResultType; +template<typename LhsScalar, typename RhsScalar> +struct conj_helper<LhsScalar, RhsScalar, true, true> { + typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar>::ReturnType ResultType; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType + pmadd(const LhsScalar& x, const RhsScalar& y, const ResultType& c) const + { return this->pmul(x, y) + c; } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType pmadd(const LhsType& x, const RhsType& y, const ResultType& c) const - { return Eigen::internal::pmadd(pconj(x), pconj(y), c); } // We save a conjuation by using the identity conj(a)*conj(b) = conj(a*b). - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType pmul(const LhsType& x, const RhsType& y) const - { return pconj(Eigen::internal::pmul(x, y)); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType + pmul(const LhsScalar& x, const RhsScalar& y) const + { return numext::conj(x * y); } }; -// Generic implementation for mixed products of complex scalar types. -template<typename RealScalar,bool Conj> struct conj_helper<std::complex<RealScalar>, RealScalar, Conj,false> +// Implementation with equal type, use packet operations. +template<typename Packet, bool ConjLhs, bool ConjRhs> +struct conj_helper<Packet, Packet, ConjLhs, ConjRhs> { - typedef std::complex<RealScalar> Scalar; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const RealScalar& y, const Scalar& c) const - { return c + conj_if<Conj>().pconj(x) * y; } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const RealScalar& y) const - { return conj_if<Conj>().pconj(x) * y; } + typedef Packet ResultType; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pmadd(const Packet& x, const Packet& y, const Packet& c) const + { return Eigen::internal::pmadd(conj_if<ConjLhs>().pconj(x), conj_if<ConjRhs>().pconj(y), c); } + + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pmul(const Packet& x, const Packet& y) const + { return Eigen::internal::pmul(conj_if<ConjLhs>().pconj(x), conj_if<ConjRhs>().pconj(y)); } }; -template<typename RealScalar,bool Conj> struct conj_helper<RealScalar, std::complex<RealScalar>, false,Conj> +template<typename Packet> +struct conj_helper<Packet, Packet, true, true> { - typedef std::complex<RealScalar> Scalar; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmadd(const RealScalar& x, const Scalar& y, const Scalar& c) const - { return c + pmul(x,y); } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmul(const RealScalar& x, const Scalar& y) const - { return x * conj_if<Conj>().pconj(y); } + typedef Packet ResultType; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pmadd(const Packet& x, const Packet& y, const Packet& c) const + { return Eigen::internal::pmadd(pconj(x), pconj(y), c); } + // We save a conjuation by using the identity conj(a)*conj(b) = conj(a*b). + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pmul(const Packet& x, const Packet& y) const + { return pconj(Eigen::internal::pmul(x, y)); } }; } // namespace internal |