diff options
author | 2009-01-14 17:41:55 +0000 | |
---|---|---|
committer | 2009-01-14 17:41:55 +0000 | |
commit | 0b606dcccd58fef640f5037088005dcdd1d3487e (patch) | |
tree | 58cadd4d9f22a517858c780ef0a12057b89e8636 /Eigen/src/Sparse | |
parent | c4c70669d165afefe0c68e7bb194ee81b9fba0b5 (diff) |
Add support for sparse * dense and dense * sparse matrix/vector products
Diffstat (limited to 'Eigen/src/Sparse')
-rw-r--r-- | Eigen/src/Sparse/SparseMatrix.h | 5 | ||||
-rw-r--r-- | Eigen/src/Sparse/SparseMatrixBase.h | 12 | ||||
-rw-r--r-- | Eigen/src/Sparse/SparseProduct.h | 130 | ||||
-rw-r--r-- | Eigen/src/Sparse/SparseUtil.h | 4 |
4 files changed, 114 insertions, 37 deletions
diff --git a/Eigen/src/Sparse/SparseMatrix.h b/Eigen/src/Sparse/SparseMatrix.h index a732bdc31..07fc0be8d 100644 --- a/Eigen/src/Sparse/SparseMatrix.h +++ b/Eigen/src/Sparse/SparseMatrix.h @@ -314,9 +314,10 @@ class SparseMatrix // 1 - compute the number of coeffs per dest inner vector // 2 - do the actual copy/eval // Since each coeff of the rhs has to be evaluated twice, let's evauluate it if needed - typedef typename ei_nested<OtherDerived,2>::type OtherCopy; - OtherCopy otherCopy(other.derived()); + //typedef typename ei_nested<OtherDerived,2>::type OtherCopy; + typedef typename ei_eval<OtherDerived>::type OtherCopy; typedef typename ei_cleantype<OtherCopy>::type _OtherCopy; + OtherCopy otherCopy(other.derived()); resize(other.rows(), other.cols()); Eigen::Map<VectorXi>(m_outerIndex,outerSize()).setZero(); diff --git a/Eigen/src/Sparse/SparseMatrixBase.h b/Eigen/src/Sparse/SparseMatrixBase.h index d01fa1ec5..14ac4e1cf 100644 --- a/Eigen/src/Sparse/SparseMatrixBase.h +++ b/Eigen/src/Sparse/SparseMatrixBase.h @@ -213,7 +213,7 @@ template<typename Derived> class SparseMatrixBase } template<typename Lhs, typename Rhs> - inline Derived& operator=(const SparseProduct<Lhs,Rhs>& product); + inline Derived& operator=(const SparseProduct<Lhs,Rhs,SparseTimeSparseProduct>& product); friend std::ostream & operator << (std::ostream & s, const SparseMatrixBase& m) { @@ -291,6 +291,16 @@ template<typename Derived> class SparseMatrixBase template<typename OtherDerived> const typename SparseProductReturnType<Derived,OtherDerived>::Type operator*(const SparseMatrixBase<OtherDerived> &other) const; + + // dense * sparse (return a dense object) + template<typename OtherDerived> friend + const typename SparseProductReturnType<OtherDerived,Derived>::Type + operator*(const MatrixBase<OtherDerived>& lhs, const Derived& rhs) + { return typename SparseProductReturnType<OtherDerived,Derived>::Type(lhs.derived(),rhs); } + + template<typename OtherDerived> + const typename SparseProductReturnType<Derived,OtherDerived>::Type + operator*(const MatrixBase<OtherDerived> &other) const; template<typename OtherDerived> Derived& operator*=(const SparseMatrixBase<OtherDerived>& other); diff --git a/Eigen/src/Sparse/SparseProduct.h b/Eigen/src/Sparse/SparseProduct.h index b4ba2ee6f..29f5208fa 100644 --- a/Eigen/src/Sparse/SparseProduct.h +++ b/Eigen/src/Sparse/SparseProduct.h @@ -25,9 +25,29 @@ #ifndef EIGEN_SPARSEPRODUCT_H #define EIGEN_SPARSEPRODUCT_H +template<typename Lhs, typename Rhs> struct ei_sparse_product_mode +{ + enum { + + value = (Rhs::Flags&Lhs::Flags&SparseBit)==SparseBit + ? SparseTimeSparseProduct + : (Lhs::Flags&SparseBit)==SparseBit + ? SparseTimeDenseProduct + : DenseTimeSparseProduct }; +}; + +template<typename Lhs, typename Rhs, int ProductMode> +struct SparseProductReturnType +{ + typedef const typename ei_nested<Lhs,Rhs::RowsAtCompileTime>::type LhsNested; + typedef const typename ei_nested<Rhs,Lhs::RowsAtCompileTime>::type RhsNested; + + typedef SparseProduct<LhsNested, RhsNested, ProductMode> Type; +}; + // sparse product return type specialization template<typename Lhs, typename Rhs> -struct SparseProductReturnType +struct SparseProductReturnType<Lhs,Rhs,SparseTimeSparseProduct> { typedef typename ei_traits<Lhs>::Scalar Scalar; enum { @@ -47,11 +67,11 @@ struct SparseProductReturnType SparseMatrix<Scalar,0>, const typename ei_nested<Rhs,Lhs::RowsAtCompileTime>::type>::ret RhsNested; - typedef SparseProduct<LhsNested, RhsNested> Type; + typedef SparseProduct<LhsNested, RhsNested, SparseTimeSparseProduct> Type; }; -template<typename LhsNested, typename RhsNested> -struct ei_traits<SparseProduct<LhsNested, RhsNested> > +template<typename LhsNested, typename RhsNested, int ProductMode> +struct ei_traits<SparseProduct<LhsNested, RhsNested, ProductMode> > { // clean the nested types: typedef typename ei_cleantype<LhsNested>::type _LhsNested; @@ -71,12 +91,13 @@ struct ei_traits<SparseProduct<LhsNested, RhsNested> > MaxRowsAtCompileTime = _LhsNested::MaxRowsAtCompileTime, MaxColsAtCompileTime = _RhsNested::MaxColsAtCompileTime, - LhsRowMajor = LhsFlags & RowMajorBit, - RhsRowMajor = RhsFlags & RowMajorBit, +// LhsIsRowMajor = (LhsFlags & RowMajorBit)==RowMajorBit, +// RhsIsRowMajor = (RhsFlags & RowMajorBit)==RowMajorBit, EvalToRowMajor = (RhsFlags & LhsFlags & RowMajorBit), + ResultIsSparse = ProductMode==SparseTimeSparseProduct, - RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit), + RemovedBits = ~( (EvalToRowMajor ? 0 : RowMajorBit) | (ResultIsSparse ? 0 : SparseBit) ), Flags = (int(LhsFlags | RhsFlags) & HereditaryBits & RemovedBits) | EvalBeforeAssigningBit @@ -84,11 +105,14 @@ struct ei_traits<SparseProduct<LhsNested, RhsNested> > CoeffReadCost = Dynamic }; + + typedef typename ei_meta_if<ResultIsSparse, + SparseMatrixBase<SparseProduct<LhsNested, RhsNested, ProductMode> >, + MatrixBase<SparseProduct<LhsNested, RhsNested, ProductMode> > >::ret Base; }; -template<typename LhsNested, typename RhsNested> -class SparseProduct : ei_no_assignment_operator, - public SparseMatrixBase<SparseProduct<LhsNested, RhsNested> > +template<typename LhsNested, typename RhsNested, int ProductMode> +class SparseProduct : ei_no_assignment_operator, public ei_traits<SparseProduct<LhsNested, RhsNested, ProductMode> >::Base { public: @@ -102,17 +126,33 @@ class SparseProduct : ei_no_assignment_operator, public: template<typename Lhs, typename Rhs> - inline SparseProduct(const Lhs& lhs, const Rhs& rhs) + EIGEN_STRONG_INLINE SparseProduct(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) { ei_assert(lhs.cols() == rhs.rows()); + + enum { + ProductIsValid = _LhsNested::ColsAtCompileTime==Dynamic + || _RhsNested::RowsAtCompileTime==Dynamic + || int(_LhsNested::ColsAtCompileTime)==int(_RhsNested::RowsAtCompileTime), + AreVectors = _LhsNested::IsVectorAtCompileTime && _RhsNested::IsVectorAtCompileTime, + SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(_LhsNested,_RhsNested) + }; + // note to the lost user: + // * for a dot product use: v1.dot(v2) + // * for a coeff-wise product use: v1.cwise()*v2 + EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes), + INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS) + EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors), + INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION) + EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT) } - inline int rows() const { return m_lhs.rows(); } - inline int cols() const { return m_rhs.cols(); } + EIGEN_STRONG_INLINE int rows() const { return m_lhs.rows(); } + EIGEN_STRONG_INLINE int cols() const { return m_rhs.cols(); } - const _LhsNested& lhs() const { return m_lhs; } - const _LhsNested& rhs() const { return m_rhs; } + EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; } + EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; } protected: LhsNested m_lhs; @@ -240,9 +280,10 @@ struct ei_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor> // return derived(); // } +// sparse = sparse * sparse template<typename Derived> template<typename Lhs, typename Rhs> -inline Derived& SparseMatrixBase<Derived>::operator=(const SparseProduct<Lhs,Rhs>& product) +inline Derived& SparseMatrixBase<Derived>::operator=(const SparseProduct<Lhs,Rhs,SparseTimeSparseProduct>& product) { // std::cout << "sparse product to sparse\n"; ei_sparse_product_selector< @@ -252,26 +293,51 @@ inline Derived& SparseMatrixBase<Derived>::operator=(const SparseProduct<Lhs,Rhs return derived(); } +// dense = sparse * dense +template<typename Derived> +template<typename Lhs, typename Rhs> +Derived& MatrixBase<Derived>::lazyAssign(const SparseProduct<Lhs,Rhs,SparseTimeDenseProduct>& product) +{ + typedef typename ei_cleantype<Lhs>::type _Lhs; + typedef typename _Lhs::InnerIterator LhsInnerIterator; + enum { LhsIsRowMajor = (_Lhs::Flags&RowMajorBit)==RowMajorBit }; + derived().setZero(); + for (int j=0; j<product.lhs().outerSize(); ++j) + for (LhsInnerIterator i(product.lhs(),j); i; ++i) + derived().row(LhsIsRowMajor ? j : i.index()) += i.value() * product.rhs().row(LhsIsRowMajor ? i.index() : j); + return derived(); +} + +// dense = dense * sparse +template<typename Derived> +template<typename Lhs, typename Rhs> +Derived& MatrixBase<Derived>::lazyAssign(const SparseProduct<Lhs,Rhs,DenseTimeSparseProduct>& product) +{ + typedef typename ei_cleantype<Rhs>::type _Rhs; + typedef typename _Rhs::InnerIterator RhsInnerIterator; + enum { RhsIsRowMajor = (_Rhs::Flags&RowMajorBit)==RowMajorBit }; + derived().setZero(); + for (int j=0; j<product.rhs().outerSize(); ++j) + for (RhsInnerIterator i(product.rhs(),j); i; ++i) + derived().col(RhsIsRowMajor ? i.index() : j) += i.value() * product.lhs().col(RhsIsRowMajor ? j : i.index()); + return derived(); +} + +// sparse * sparse template<typename Derived> template<typename OtherDerived> -inline const typename SparseProductReturnType<Derived,OtherDerived>::Type +EIGEN_STRONG_INLINE const typename SparseProductReturnType<Derived,OtherDerived>::Type SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const { - enum { - ProductIsValid = Derived::ColsAtCompileTime==Dynamic - || OtherDerived::RowsAtCompileTime==Dynamic - || int(Derived::ColsAtCompileTime)==int(OtherDerived::RowsAtCompileTime), - AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime, - SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(Derived,OtherDerived) - }; - // note to the lost user: - // * for a dot product use: v1.dot(v2) - // * for a coeff-wise product use: v1.cwise()*v2 - EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes), - INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS) - EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors), - INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION) - EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT) + return typename SparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived()); +} + +// sparse * dense +template<typename Derived> +template<typename OtherDerived> +EIGEN_STRONG_INLINE const typename SparseProductReturnType<Derived,OtherDerived>::Type +SparseMatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const +{ return typename SparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived()); } diff --git a/Eigen/src/Sparse/SparseUtil.h b/Eigen/src/Sparse/SparseUtil.h index 724fb9efb..046523d8f 100644 --- a/Eigen/src/Sparse/SparseUtil.h +++ b/Eigen/src/Sparse/SparseUtil.h @@ -109,10 +109,10 @@ template<typename MatrixType> class SparseInnerVector; template<typename Derived> class SparseCwise; template<typename UnaryOp, typename MatrixType> class SparseCwiseUnaryOp; template<typename BinaryOp, typename Lhs, typename Rhs> class SparseCwiseBinaryOp; -template<typename Lhs, typename Rhs> class SparseProduct; template<typename ExpressionType, unsigned int Added, unsigned int Removed> class SparseFlagged; -template<typename Lhs, typename Rhs> struct SparseProductReturnType; +template<typename Lhs, typename Rhs> struct ei_sparse_product_mode; +template<typename Lhs, typename Rhs, int ProductMode = ei_sparse_product_mode<Lhs,Rhs>::value> struct SparseProductReturnType; const int AccessPatternNotSupported = 0x0; const int AccessPatternSupported = 0x1; |