diff options
Diffstat (limited to 'Eigen/src/Sparse/SparseProduct.h')
-rw-r--r-- | Eigen/src/Sparse/SparseProduct.h | 130 |
1 files changed, 98 insertions, 32 deletions
diff --git a/Eigen/src/Sparse/SparseProduct.h b/Eigen/src/Sparse/SparseProduct.h index b4ba2ee6f..29f5208fa 100644 --- a/Eigen/src/Sparse/SparseProduct.h +++ b/Eigen/src/Sparse/SparseProduct.h @@ -25,9 +25,29 @@ #ifndef EIGEN_SPARSEPRODUCT_H #define EIGEN_SPARSEPRODUCT_H +template<typename Lhs, typename Rhs> struct ei_sparse_product_mode +{ + enum { + + value = (Rhs::Flags&Lhs::Flags&SparseBit)==SparseBit + ? SparseTimeSparseProduct + : (Lhs::Flags&SparseBit)==SparseBit + ? SparseTimeDenseProduct + : DenseTimeSparseProduct }; +}; + +template<typename Lhs, typename Rhs, int ProductMode> +struct SparseProductReturnType +{ + typedef const typename ei_nested<Lhs,Rhs::RowsAtCompileTime>::type LhsNested; + typedef const typename ei_nested<Rhs,Lhs::RowsAtCompileTime>::type RhsNested; + + typedef SparseProduct<LhsNested, RhsNested, ProductMode> Type; +}; + // sparse product return type specialization template<typename Lhs, typename Rhs> -struct SparseProductReturnType +struct SparseProductReturnType<Lhs,Rhs,SparseTimeSparseProduct> { typedef typename ei_traits<Lhs>::Scalar Scalar; enum { @@ -47,11 +67,11 @@ struct SparseProductReturnType SparseMatrix<Scalar,0>, const typename ei_nested<Rhs,Lhs::RowsAtCompileTime>::type>::ret RhsNested; - typedef SparseProduct<LhsNested, RhsNested> Type; + typedef SparseProduct<LhsNested, RhsNested, SparseTimeSparseProduct> Type; }; -template<typename LhsNested, typename RhsNested> -struct ei_traits<SparseProduct<LhsNested, RhsNested> > +template<typename LhsNested, typename RhsNested, int ProductMode> +struct ei_traits<SparseProduct<LhsNested, RhsNested, ProductMode> > { // clean the nested types: typedef typename ei_cleantype<LhsNested>::type _LhsNested; @@ -71,12 +91,13 @@ struct ei_traits<SparseProduct<LhsNested, RhsNested> > MaxRowsAtCompileTime = _LhsNested::MaxRowsAtCompileTime, MaxColsAtCompileTime = _RhsNested::MaxColsAtCompileTime, - LhsRowMajor = LhsFlags & RowMajorBit, - RhsRowMajor = RhsFlags & RowMajorBit, +// LhsIsRowMajor = (LhsFlags & RowMajorBit)==RowMajorBit, +// RhsIsRowMajor = (RhsFlags & RowMajorBit)==RowMajorBit, EvalToRowMajor = (RhsFlags & LhsFlags & RowMajorBit), + ResultIsSparse = ProductMode==SparseTimeSparseProduct, - RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit), + RemovedBits = ~( (EvalToRowMajor ? 0 : RowMajorBit) | (ResultIsSparse ? 0 : SparseBit) ), Flags = (int(LhsFlags | RhsFlags) & HereditaryBits & RemovedBits) | EvalBeforeAssigningBit @@ -84,11 +105,14 @@ struct ei_traits<SparseProduct<LhsNested, RhsNested> > CoeffReadCost = Dynamic }; + + typedef typename ei_meta_if<ResultIsSparse, + SparseMatrixBase<SparseProduct<LhsNested, RhsNested, ProductMode> >, + MatrixBase<SparseProduct<LhsNested, RhsNested, ProductMode> > >::ret Base; }; -template<typename LhsNested, typename RhsNested> -class SparseProduct : ei_no_assignment_operator, - public SparseMatrixBase<SparseProduct<LhsNested, RhsNested> > +template<typename LhsNested, typename RhsNested, int ProductMode> +class SparseProduct : ei_no_assignment_operator, public ei_traits<SparseProduct<LhsNested, RhsNested, ProductMode> >::Base { public: @@ -102,17 +126,33 @@ class SparseProduct : ei_no_assignment_operator, public: template<typename Lhs, typename Rhs> - inline SparseProduct(const Lhs& lhs, const Rhs& rhs) + EIGEN_STRONG_INLINE SparseProduct(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) { ei_assert(lhs.cols() == rhs.rows()); + + enum { + ProductIsValid = _LhsNested::ColsAtCompileTime==Dynamic + || _RhsNested::RowsAtCompileTime==Dynamic + || int(_LhsNested::ColsAtCompileTime)==int(_RhsNested::RowsAtCompileTime), + AreVectors = _LhsNested::IsVectorAtCompileTime && _RhsNested::IsVectorAtCompileTime, + SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(_LhsNested,_RhsNested) + }; + // note to the lost user: + // * for a dot product use: v1.dot(v2) + // * for a coeff-wise product use: v1.cwise()*v2 + EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes), + INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS) + EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors), + INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION) + EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT) } - inline int rows() const { return m_lhs.rows(); } - inline int cols() const { return m_rhs.cols(); } + EIGEN_STRONG_INLINE int rows() const { return m_lhs.rows(); } + EIGEN_STRONG_INLINE int cols() const { return m_rhs.cols(); } - const _LhsNested& lhs() const { return m_lhs; } - const _LhsNested& rhs() const { return m_rhs; } + EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; } + EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; } protected: LhsNested m_lhs; @@ -240,9 +280,10 @@ struct ei_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor> // return derived(); // } +// sparse = sparse * sparse template<typename Derived> template<typename Lhs, typename Rhs> -inline Derived& SparseMatrixBase<Derived>::operator=(const SparseProduct<Lhs,Rhs>& product) +inline Derived& SparseMatrixBase<Derived>::operator=(const SparseProduct<Lhs,Rhs,SparseTimeSparseProduct>& product) { // std::cout << "sparse product to sparse\n"; ei_sparse_product_selector< @@ -252,26 +293,51 @@ inline Derived& SparseMatrixBase<Derived>::operator=(const SparseProduct<Lhs,Rhs return derived(); } +// dense = sparse * dense +template<typename Derived> +template<typename Lhs, typename Rhs> +Derived& MatrixBase<Derived>::lazyAssign(const SparseProduct<Lhs,Rhs,SparseTimeDenseProduct>& product) +{ + typedef typename ei_cleantype<Lhs>::type _Lhs; + typedef typename _Lhs::InnerIterator LhsInnerIterator; + enum { LhsIsRowMajor = (_Lhs::Flags&RowMajorBit)==RowMajorBit }; + derived().setZero(); + for (int j=0; j<product.lhs().outerSize(); ++j) + for (LhsInnerIterator i(product.lhs(),j); i; ++i) + derived().row(LhsIsRowMajor ? j : i.index()) += i.value() * product.rhs().row(LhsIsRowMajor ? i.index() : j); + return derived(); +} + +// dense = dense * sparse +template<typename Derived> +template<typename Lhs, typename Rhs> +Derived& MatrixBase<Derived>::lazyAssign(const SparseProduct<Lhs,Rhs,DenseTimeSparseProduct>& product) +{ + typedef typename ei_cleantype<Rhs>::type _Rhs; + typedef typename _Rhs::InnerIterator RhsInnerIterator; + enum { RhsIsRowMajor = (_Rhs::Flags&RowMajorBit)==RowMajorBit }; + derived().setZero(); + for (int j=0; j<product.rhs().outerSize(); ++j) + for (RhsInnerIterator i(product.rhs(),j); i; ++i) + derived().col(RhsIsRowMajor ? i.index() : j) += i.value() * product.lhs().col(RhsIsRowMajor ? j : i.index()); + return derived(); +} + +// sparse * sparse template<typename Derived> template<typename OtherDerived> -inline const typename SparseProductReturnType<Derived,OtherDerived>::Type +EIGEN_STRONG_INLINE const typename SparseProductReturnType<Derived,OtherDerived>::Type SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const { - enum { - ProductIsValid = Derived::ColsAtCompileTime==Dynamic - || OtherDerived::RowsAtCompileTime==Dynamic - || int(Derived::ColsAtCompileTime)==int(OtherDerived::RowsAtCompileTime), - AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime, - SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(Derived,OtherDerived) - }; - // note to the lost user: - // * for a dot product use: v1.dot(v2) - // * for a coeff-wise product use: v1.cwise()*v2 - EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes), - INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS) - EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors), - INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION) - EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT) + return typename SparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived()); +} + +// sparse * dense +template<typename Derived> +template<typename OtherDerived> +EIGEN_STRONG_INLINE const typename SparseProductReturnType<Derived,OtherDerived>::Type +SparseMatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const +{ return typename SparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived()); } |