diff options
Diffstat (limited to 'Eigen/src/Core/util/BlasUtil.h')
-rwxr-xr-x | Eigen/src/Core/util/BlasUtil.h | 53 |
1 files changed, 41 insertions, 12 deletions
diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h
index 498db3a70..6e6ee119b 100755
--- a/Eigen/src/Core/util/BlasUtil.h
+++ b/Eigen/src/Core/util/BlasUtil.h
@@ -44,16 +44,29 @@ template<bool Conjugate> struct conj_if;
 
 template<> struct conj_if<true> {
   template<typename T>
-  inline T operator()(const T& x) { return numext::conj(x); }
+  inline T operator()(const T& x) const { return numext::conj(x); }
   template<typename T>
-  inline T pconj(const T& x) { return internal::pconj(x); }
+  inline T pconj(const T& x) const { return internal::pconj(x); }
 };
 
 template<> struct conj_if<false> {
   template<typename T>
-  inline const T& operator()(const T& x) { return x; }
+  inline const T& operator()(const T& x) const { return x; }
   template<typename T>
-  inline const T& pconj(const T& x) { return x; }
+  inline const T& pconj(const T& x) const { return x; }
+};
+
+// Generic implementation for custom complex types.
+template<typename LhsScalar, typename RhsScalar, bool ConjLhs, bool ConjRhs>
+struct conj_helper
+{
+  typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar>::ReturnType Scalar;
+
+  EIGEN_STRONG_INLINE Scalar pmadd(const LhsScalar& x, const RhsScalar& y, const Scalar& c) const
+  { return padd(c, pmul(x,y)); }
+
+  EIGEN_STRONG_INLINE Scalar pmul(const LhsScalar& x, const RhsScalar& y) const
+  { return conj_if<ConjLhs>()(x) * conj_if<ConjRhs>()(y); }
 };
 
 template<typename Scalar> struct conj_helper<Scalar,Scalar,false,false>
@@ -111,7 +124,7 @@ template<typename RealScalar,bool Conj> struct conj_helper<RealScalar, std::comp
 };
 
 template<typename From,typename To> struct get_factor {
-  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE To run(const From& x) { return x; }
+  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE To run(const From& x) { return To(x); }
 };
 
 template<typename Scalar> struct get_factor<Scalar,typename NumTraits<Scalar>::Real> {
@@ -135,7 +148,7 @@ class BlasVectorMapper {
 
   template <typename Packet>
   EIGEN_DEVICE_FUNC bool aligned(Index i) const {
-    return (size_t(m_data+i)%sizeof(Packet))==0;
+    return (UIntPtr(m_data+i)%sizeof(Packet))==0;
   }
 
   protected:
@@ -227,7 +240,7 @@ class blas_data_mapper {
   EIGEN_DEVICE_FUNC const Scalar* data() const { return m_data; }
 
   EIGEN_DEVICE_FUNC Index firstAligned(Index size) const {
-    if (size_t(m_data)%sizeof(Scalar)) {
+    if (UIntPtr(m_data)%sizeof(Scalar)) {
       return -1;
     }
     return internal::first_default_aligned(m_data, size);
@@ -293,17 +306,33 @@ struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> >
 };
 
 // pop scalar multiple
-template<typename Scalar, typename NestedXpr>
-struct blas_traits<CwiseUnaryOp<scalar_multiple_op<Scalar>, NestedXpr> >
+template<typename Scalar, typename NestedXpr, typename Plain>
+struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> >
  : blas_traits<NestedXpr>
 {
   typedef blas_traits<NestedXpr> Base;
-  typedef CwiseUnaryOp<scalar_multiple_op<Scalar>, NestedXpr> XprType;
+  typedef CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> XprType;
   typedef typename Base::ExtractType ExtractType;
-  static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
+  static inline ExtractType extract(const XprType& x) { return Base::extract(x.rhs()); }
   static inline Scalar extractScalarFactor(const XprType& x)
-  { return x.functor().m_other * Base::extractScalarFactor(x.nestedExpression()); }
+  { return x.lhs().functor().m_other * Base::extractScalarFactor(x.rhs()); }
 };
+template<typename Scalar, typename NestedXpr, typename Plain>
+struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > >
+ : blas_traits<NestedXpr>
+{
+  typedef blas_traits<NestedXpr> Base;
+  typedef CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > XprType;
+  typedef typename Base::ExtractType ExtractType;
+  static inline ExtractType extract(const XprType& x) { return Base::extract(x.lhs()); }
+  static inline Scalar extractScalarFactor(const XprType& x)
+  { return Base::extractScalarFactor(x.lhs()) * x.rhs().functor().m_other; }
+};
+template<typename Scalar, typename Plain1, typename Plain2>
+struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain1>,
+                                 const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain2> > >
+ : blas_traits<CwiseNullaryOp<scalar_constant_op<Scalar>,Plain1> >
+{};
 
 // pop opposite
 template<typename Scalar, typename NestedXpr>