aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorRasmus Munk Larsen <rmlarsen@google.com>2021-07-02 13:36:05 -0700
committerRasmus Munk Larsen <rmlarsen@google.com>2021-07-02 20:42:15 +0000
commit7b35638ddb99a0298c5d3450de506a8e8e0203d3 (patch)
treee3c2b4c81330948aea0fe71816b739fef91fae6f
parentaab747021be5ed1a1e9667243d884eb72003599d (diff)
Fix breakage of conj_helper in conjunction with custom types introduced in !537.
-rw-r--r--Eigen/src/Core/arch/Default/ConjHelper.h68
1 files changed, 39 insertions, 29 deletions
diff --git a/Eigen/src/Core/arch/Default/ConjHelper.h b/Eigen/src/Core/arch/Default/ConjHelper.h
index 255daddc5..53830b5a2 100644
--- a/Eigen/src/Core/arch/Default/ConjHelper.h
+++ b/Eigen/src/Core/arch/Default/ConjHelper.h
@@ -57,48 +57,58 @@ template<> struct conj_if<false> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& pconj(const T& x) const { return x; }
};
-// Generic implementation.
+// Generic Implementation, assume scalars since the packet-version is
+// specialized below.
template<typename LhsType, typename RhsType, bool ConjLhs, bool ConjRhs>
-struct conj_helper
-{
- typedef typename ScalarBinaryOpTraits<LhsType,RhsType>::ReturnType ResultType;
+struct conj_helper {
+ typedef typename ScalarBinaryOpTraits<LhsType, RhsType>::ReturnType ResultType;
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType pmadd(const LhsType& x, const RhsType& y, const ResultType& c) const
- { return Eigen::internal::pmadd(conj_if<ConjLhs>().pconj(x), conj_if<ConjRhs>().pconj(y), c); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType
+ pmadd(const LhsType& x, const RhsType& y, const ResultType& c) const
+ { return this->pmul(x, y) + c; }
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType pmul(const LhsType& x, const RhsType& y) const
- { return Eigen::internal::pmul(conj_if<ConjLhs>().pconj(x), conj_if<ConjRhs>().pconj(y)); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType
+ pmul(const LhsType& x, const RhsType& y) const
+ { return conj_if<ConjLhs>()(x) * conj_if<ConjRhs>()(y); }
};
-template<typename LhsType, typename RhsType>
-struct conj_helper<LhsType, RhsType, true, true>
-{
- typedef typename ScalarBinaryOpTraits<LhsType,RhsType>::ReturnType ResultType;
+template<typename LhsScalar, typename RhsScalar>
+struct conj_helper<LhsScalar, RhsScalar, true, true> {
+ typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar>::ReturnType ResultType;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType
+ pmadd(const LhsScalar& x, const RhsScalar& y, const ResultType& c) const
+ { return this->pmul(x, y) + c; }
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType pmadd(const LhsType& x, const RhsType& y, const ResultType& c) const
- { return Eigen::internal::pmadd(pconj(x), pconj(y), c); }
// We save a conjuation by using the identity conj(a)*conj(b) = conj(a*b).
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType pmul(const LhsType& x, const RhsType& y) const
- { return pconj(Eigen::internal::pmul(x, y)); }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType
+ pmul(const LhsScalar& x, const RhsScalar& y) const
+ { return numext::conj(x * y); }
};
-// Generic implementation for mixed products of complex scalar types.
-template<typename RealScalar,bool Conj> struct conj_helper<std::complex<RealScalar>, RealScalar, Conj,false>
+// Implementation with equal type, use packet operations.
+template<typename Packet, bool ConjLhs, bool ConjRhs>
+struct conj_helper<Packet, Packet, ConjLhs, ConjRhs>
{
- typedef std::complex<RealScalar> Scalar;
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const RealScalar& y, const Scalar& c) const
- { return c + conj_if<Conj>().pconj(x) * y; }
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const RealScalar& y) const
- { return conj_if<Conj>().pconj(x) * y; }
+ typedef Packet ResultType;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pmadd(const Packet& x, const Packet& y, const Packet& c) const
+ { return Eigen::internal::pmadd(conj_if<ConjLhs>().pconj(x), conj_if<ConjRhs>().pconj(y), c); }
+
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pmul(const Packet& x, const Packet& y) const
+ { return Eigen::internal::pmul(conj_if<ConjLhs>().pconj(x), conj_if<ConjRhs>().pconj(y)); }
};
-template<typename RealScalar,bool Conj> struct conj_helper<RealScalar, std::complex<RealScalar>, false,Conj>
+template<typename Packet>
+struct conj_helper<Packet, Packet, true, true>
{
- typedef std::complex<RealScalar> Scalar;
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmadd(const RealScalar& x, const Scalar& y, const Scalar& c) const
- { return c + pmul(x,y); }
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmul(const RealScalar& x, const Scalar& y) const
- { return x * conj_if<Conj>().pconj(y); }
+ typedef Packet ResultType;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pmadd(const Packet& x, const Packet& y, const Packet& c) const
+ { return Eigen::internal::pmadd(pconj(x), pconj(y), c); }
+ // We save a conjuation by using the identity conj(a)*conj(b) = conj(a*b).
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pmul(const Packet& x, const Packet& y) const
+ { return pconj(Eigen::internal::pmul(x, y)); }
};
} // namespace internal