diff options
author | Gael Guennebaud <g.gael@free.fr> | 2009-07-07 16:55:51 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2009-07-07 16:55:51 +0200 |
commit | 5ed6ce90d3d626e86127961f0845570223ac9c0b (patch) | |
tree | 61b005e40183e9b67ed6cc9825aa4363af10d8f0 /Eigen/src/Core | |
parent | ea23f36c7843854cfcfc3fbfec4b65c935e56456 (diff) |
started to catch scalar multiple and conjugate xpr in Product
Diffstat (limited to 'Eigen/src/Core')
-rw-r--r-- | Eigen/src/Core/CwiseUnaryOp.h | 6 | ||||
-rw-r--r-- | Eigen/src/Core/Product.h | 53 |
2 files changed, 59 insertions, 0 deletions
diff --git a/Eigen/src/Core/CwiseUnaryOp.h b/Eigen/src/Core/CwiseUnaryOp.h index a36a629db..0095a1572 100644 --- a/Eigen/src/Core/CwiseUnaryOp.h +++ b/Eigen/src/Core/CwiseUnaryOp.h @@ -92,6 +92,12 @@ class CwiseUnaryOp : ei_no_assignment_operator, return m_functor.packetOp(m_matrix.template packet<LoadMode>(index)); } + /** \internal used for introspection */ + const UnaryOp& _functor() const { return m_functor; } + + /** \internal used for introspection */ + const typename MatrixType::Nested& _expression() const { return m_matrix; } + protected: const typename MatrixType::Nested m_matrix; const UnaryOp m_functor; diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h index 0652eb615..6849d90e3 100644 --- a/Eigen/src/Core/Product.h +++ b/Eigen/src/Core/Product.h @@ -92,6 +92,50 @@ template<typename Lhs, typename Rhs> struct ei_product_mode : NormalProduct }; }; +template<typename XprType> struct ei_product_factor_traits +{ + typedef typename ei_traits<XprType>::Scalar Scalar; + typedef XprType RealXprType; + enum { + IsComplex = NumTraits<Scalar>::IsComplex, + NeedToConjugate = false, + HasScalarMultiple = false, + Access = int(ei_traits<XprType>::Flags)&DirectAccessBit ? HasDirectAccess : NoDirectAccess + }; + static inline const RealXprType& extract(const XprType& x) { return x; } + static inline Scalar extractSalarFactor(const XprType&) { return Scalar(1); } +}; + +// pop conjugate +template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<CwiseUnaryOp<ei_scalar_conjugate_op<Scalar>, NestedXpr> > + : ei_product_factor_traits<NestedXpr> +{ + typedef ei_product_factor_traits<NestedXpr> Base; + typedef CwiseUnaryOp<ei_scalar_conjugate_op<Scalar>, NestedXpr> XprType; + typedef typename Base::RealXprType RealXprType; + + enum { + IsComplex = NumTraits<Scalar>::IsComplex, + NeedToConjugate = IsComplex + }; + static inline const RealXprType& extract(const XprType& x) { return x._expression(); } + static inline Scalar extractSalarFactor(const XprType& x) { return Base::extractSalarFactor(x._expression()); } +}; + +// pop scalar multiple +template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<CwiseUnaryOp<ei_scalar_multiple_op<Scalar>, NestedXpr> > + : ei_product_factor_traits<NestedXpr> +{ + typedef ei_product_factor_traits<NestedXpr> Base; + typedef CwiseUnaryOp<ei_scalar_multiple_op<Scalar>, NestedXpr> XprType; + typedef typename Base::RealXprType RealXprType; + enum { + HasScalarMultiple = true + }; + static inline const RealXprType& extract(const XprType& x) { return x._expression(); } + static inline Scalar extractSalarFactor(const XprType& x) { return x._functor().value; } +}; + /** \class Product * * \brief Expression of the product of two matrices @@ -517,6 +561,15 @@ template<typename Scalar, typename ResType> static void ei_cache_friendly_product_rowmajor_times_vector( const Scalar* lhs, int lhsStride, const Scalar* rhs, int rhsSize, ResType& res, Scalar alpha); +// This helper class aims to determine which optimized product to call, +// and how to call it. We have to distinghish three major cases: +// 1 - matrix-matrix +// 2 - matrix-vector +// 3 - vector-matrix +// The storage order, and direct-access criteria are also important for in last 2 cases. +// For instance, with a mat-vec product, the matrix coeff are evaluated only once, and +// therefore it is useless to first evaluated it to next being able to directly access +// its coefficient. template<typename ProductType, int LhsRows = ei_traits<ProductType>::RowsAtCompileTime, int LhsOrder = int(ei_traits<ProductType>::LhsFlags)&RowMajorBit ? RowMajor : ColMajor, |