diff options
author | Gael Guennebaud <g.gael@free.fr> | 2012-03-01 10:13:13 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2012-03-01 10:13:13 +0100 |
commit | 553a0ae924d8307944ab8465bbe2106c449ea2d7 (patch) | |
tree | cae7c8430b8391238b328cd28ee622436635968d /Eigen | |
parent | 85b358097d50bb2f3c95fb41fa6879faa533ab0d (diff) |
simplify and speedup sparse * dense matrix products
Diffstat (limited to 'Eigen')
-rw-r--r-- | Eigen/src/SparseCore/SparseDenseProduct.h | 127 |
1 files changed, 101 insertions, 26 deletions
diff --git a/Eigen/src/SparseCore/SparseDenseProduct.h b/Eigen/src/SparseCore/SparseDenseProduct.h index b372853dc..0bdaee21e 100644 --- a/Eigen/src/SparseCore/SparseDenseProduct.h +++ b/Eigen/src/SparseCore/SparseDenseProduct.h @@ -149,6 +149,102 @@ struct traits<SparseTimeDenseProduct<Lhs,Rhs> > typedef Dense StorageKind; typedef MatrixXpr XprKind; }; + +template<typename SparseLhsType, typename DenseRhsType, typename DenseResType, + int LhsStorageOrder = SparseLhsType::IsRowMajor?RowMajor:ColMajor, + bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1> +struct sparse_time_dense_product_impl; + +template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> +struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, true> +{ + typedef typename internal::remove_all<SparseLhsType>::type Lhs; + typedef typename internal::remove_all<DenseRhsType>::type Rhs; + typedef typename internal::remove_all<DenseResType>::type Res; + typedef typename Lhs::Index Index; + typedef typename Lhs::InnerIterator LhsInnerIterator; + static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha) + { + for(Index c=0; c<rhs.cols(); ++c) + { + Index j=0; + for(j=0; j<lhs.outerSize(); ++j) + { + typename Res::Scalar tmp(0); + for(LhsInnerIterator it(lhs,j); it ;++it) + tmp += it.value() * rhs.coeff(it.index(),c); + res.coeffRef(j,c) = alpha * tmp; + } + } + } +}; + +template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> +struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, true> +{ + typedef typename internal::remove_all<SparseLhsType>::type Lhs; + typedef typename internal::remove_all<DenseRhsType>::type Rhs; + typedef typename internal::remove_all<DenseResType>::type Res; + typedef typename Lhs::InnerIterator LhsInnerIterator; + typedef typename Lhs::Index Index; + static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha) + { + for(Index c=0; c<rhs.cols(); ++c) + { + for(Index j=0; j<lhs.outerSize(); ++j) + { + typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c); + for(LhsInnerIterator it(lhs,j); it ;++it) + res.coeffRef(it.index(),c) += it.value() * rhs_j; + } + } + } +}; + +template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> +struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, false> +{ + typedef typename internal::remove_all<SparseLhsType>::type Lhs; + typedef typename internal::remove_all<DenseRhsType>::type Rhs; + typedef typename internal::remove_all<DenseResType>::type Res; + typedef typename Lhs::InnerIterator LhsInnerIterator; + typedef typename Lhs::Index Index; + static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha) + { + for(Index j=0; j<lhs.outerSize(); ++j) + { + typename Res::RowXpr res_j(res.row(j)); + for(LhsInnerIterator it(lhs,j); it ;++it) + res_j += (alpha*it.value()) * rhs.row(it.index()); + } + } +}; + +template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> +struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, false> +{ + typedef typename internal::remove_all<SparseLhsType>::type Lhs; + typedef typename internal::remove_all<DenseRhsType>::type Rhs; + typedef typename internal::remove_all<DenseResType>::type Res; + typedef typename Lhs::InnerIterator LhsInnerIterator; + typedef typename Lhs::Index Index; + static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha) + { + for(Index j=0; j<lhs.outerSize(); ++j) + { + typename Rhs::ConstRowXpr rhs_j(rhs.row(j)); + for(LhsInnerIterator it(lhs,j); it ;++it) + res.row(it.index()) += (alpha*it.value()) * rhs_j; + } + } +}; + +template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType> +inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha) +{ + sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType>::run(lhs, rhs, res, alpha); +} + } // end namespace internal template<typename Lhs, typename Rhs> @@ -163,27 +259,7 @@ class SparseTimeDenseProduct template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const { - typedef typename _LhsNested::InnerIterator LhsInnerIterator; - enum { - LhsIsRowMajor = (_LhsNested::Flags&RowMajorBit)==RowMajorBit, - RhsIsVector = _RhsNested::ColsAtCompileTime==1 - }; - Index j=0; - for(j=0; j<m_lhs.outerSize(); ++j) - { - typename Rhs::Scalar rhs_j = alpha * m_rhs.coeff(LhsIsRowMajor ? 0 : j,0); - typename Dest::RowXpr dest_j(dest.row(LhsIsRowMajor ? j : 0)); - typename Dest::Scalar tmp(0); - for(LhsInnerIterator it(m_lhs,j); it ;++it) - { - if(LhsIsRowMajor && RhsIsVector) tmp += (it.value()) * m_rhs.coeff(it.index()); - else if(LhsIsRowMajor) dest_j += (alpha*it.value()) * m_rhs.row(it.index()); - else if(RhsIsVector) dest.coeffRef(it.index()) += it.value() * rhs_j; - else dest.row(it.index()) += (alpha*it.value()) * m_rhs.row(j); - } - if(LhsIsRowMajor && RhsIsVector) - dest.coeffRef(LhsIsRowMajor ? j : 0) = alpha * tmp; - } + internal::sparse_time_dense_product(m_lhs, m_rhs, dest, alpha); } private: @@ -213,11 +289,10 @@ class DenseTimeSparseProduct template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const { - typedef typename _RhsNested::InnerIterator RhsInnerIterator; - enum { RhsIsRowMajor = (_RhsNested::Flags&RowMajorBit)==RowMajorBit }; - for(Index j=0; j<m_rhs.outerSize(); ++j) - for(RhsInnerIterator i(m_rhs,j); i; ++i) - dest.col(RhsIsRowMajor ? i.index() : j) += (alpha*i.value()) * m_lhs.col(RhsIsRowMajor ? j : i.index()); + Transpose<const _LhsNested> lhs_t(m_lhs); + Transpose<const _RhsNested> rhs_t(m_rhs); + Transpose<Dest> dest_t(dest); + internal::sparse_time_dense_product(rhs_t, lhs_t, dest_t, alpha); } private: |