aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen
diff options
context:
space:
mode:
author: Gael Guennebaud <g.gael@free.fr> 2012-03-01 10:13:13 +0100
committer: Gael Guennebaud <g.gael@free.fr> 2012-03-01 10:13:13 +0100
commit: 553a0ae924d8307944ab8465bbe2106c449ea2d7 (patch)
tree: cae7c8430b8391238b328cd28ee622436635968d /Eigen
parent: 85b358097d50bb2f3c95fb41fa6879faa533ab0d (diff)
Simplify and speed up sparse * dense matrix products
Diffstat (limited to 'Eigen')
-rw-r--r-- Eigen/src/SparseCore/SparseDenseProduct.h 127
1 files changed, 101 insertions, 26 deletions
diff --git a/Eigen/src/SparseCore/SparseDenseProduct.h b/Eigen/src/SparseCore/SparseDenseProduct.h
index b372853dc..0bdaee21e 100644
--- a/Eigen/src/SparseCore/SparseDenseProduct.h
+++ b/Eigen/src/SparseCore/SparseDenseProduct.h
@@ -149,6 +149,102 @@ struct traits<SparseTimeDenseProduct<Lhs,Rhs> >
typedef Dense StorageKind;
typedef MatrixXpr XprKind;
};
+
+template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,
+ int LhsStorageOrder = SparseLhsType::IsRowMajor?RowMajor:ColMajor, // defaults to the storage order reported by the sparse lhs
+ bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1> // true unless the rhs is row-major with a compile-time column count other than 1
+struct sparse_time_dense_product_impl; // declaration only: the (LhsStorageOrder, ColPerCol) specializations each provide a static run()
+
+/** \internal
+  * Kernel for a row-major sparse lhs when the dense rhs is processed column by column.
+  *
+  * For each column c of rhs and each outer index j of lhs, the sparse dot product
+  * between the j-th outer vector of lhs and column c of rhs is computed, scaled by
+  * alpha, and ADDED to res(j,c). In other words: res += alpha * lhs * rhs.
+  *
+  * The accumulation (rather than an overwrite) matches the three other
+  * specializations of sparse_time_dense_product_impl and the scaleAndAddTo()
+  * entry points that forward here; overwriting would make the result depend on
+  * which kernel happens to be selected.
+  *
+  * \param lhs   sparse left-hand side, traversed through its InnerIterator
+  * \param rhs   dense right-hand side, read coefficient-wise
+  * \param res   dense destination, updated in place
+  * \param alpha scaling factor applied to every computed dot product
+  */
+template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
+struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, true>
+{
+  typedef typename internal::remove_all<SparseLhsType>::type Lhs;
+  typedef typename internal::remove_all<DenseRhsType>::type Rhs;
+  typedef typename internal::remove_all<DenseResType>::type Res;
+  typedef typename Lhs::Index Index;
+  typedef typename Lhs::InnerIterator LhsInnerIterator;
+  static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
+  {
+    for(Index c=0; c<rhs.cols(); ++c)
+    {
+      for(Index j=0; j<lhs.outerSize(); ++j)
+      {
+        // Accumulate the dot product in a local scalar so that res(j,c) is
+        // touched only once per (j,c) pair.
+        typename Res::Scalar tmp(0);
+        for(LhsInnerIterator it(lhs,j); it ;++it)
+          tmp += it.value() * rhs.coeff(it.index(),c);
+        // "+=" and not "=": the destination must keep its previous content,
+        // exactly like in the other kernels.
+        res.coeffRef(j,c) += alpha * tmp;
+      }
+    }
+  }
+};
+
+/** \internal
+  * Kernel for a column-major sparse lhs when the dense rhs is processed column by column.
+  *
+  * For every rhs coefficient (j,c), the j-th outer vector of lhs is multiplied by
+  * alpha*rhs(j,c) and added, entry by entry, into column c of res.
+  * res is updated in place (res += alpha * lhs * rhs).
+  */
+template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
+struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, true>
+{
+  typedef typename internal::remove_all<SparseLhsType>::type Lhs;
+  typedef typename internal::remove_all<DenseRhsType>::type Rhs;
+  typedef typename internal::remove_all<DenseResType>::type Res;
+  typedef typename Lhs::InnerIterator LhsInnerIterator;
+  typedef typename Lhs::Index Index;
+  static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
+  {
+    typedef typename Res::Scalar ResScalar;
+    for(Index col=0; col<rhs.cols(); ++col)
+    {
+      for(Index outer=0; outer<lhs.outerSize(); ++outer)
+      {
+        // The scaled rhs coefficient is shared by all nonzeros of this outer vector.
+        ResScalar scaled = alpha * rhs.coeff(outer,col);
+        LhsInnerIterator it(lhs,outer);
+        while(it)
+        {
+          res.coeffRef(it.index(),col) += it.value() * scaled;
+          ++it;
+        }
+      }
+    }
+  }
+};
+
+/** \internal
+  * Kernel for a row-major sparse lhs when the dense rhs is processed row by row
+  * (ColPerCol == false).
+  *
+  * For each outer index j of lhs, every nonzero (j, k) adds
+  * (alpha * value) * rhs.row(k) to row j of res. res is updated in place.
+  */
+template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
+struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, false>
+{
+  typedef typename internal::remove_all<SparseLhsType>::type Lhs;
+  typedef typename internal::remove_all<DenseRhsType>::type Rhs;
+  typedef typename internal::remove_all<DenseResType>::type Res;
+  typedef typename Lhs::InnerIterator LhsInnerIterator;
+  typedef typename Lhs::Index Index;
+  static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
+  {
+    typedef typename Res::RowXpr ResRow;
+    for(Index outer=0; outer<lhs.outerSize(); ++outer)
+    {
+      // One destination row expression per outer vector, reused for all its nonzeros.
+      ResRow dstRow(res.row(outer));
+      LhsInnerIterator it(lhs,outer);
+      while(it)
+      {
+        dstRow += (alpha*it.value()) * rhs.row(it.index());
+        ++it;
+      }
+    }
+  }
+};
+
+/** \internal
+  * Kernel for a column-major sparse lhs when the dense rhs is processed row by row
+  * (ColPerCol == false).
+  *
+  * For each outer index j of lhs, every nonzero (i, j) adds
+  * (alpha * value) * rhs.row(j) to row i of res. res is updated in place.
+  */
+template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
+struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, false>
+{
+  typedef typename internal::remove_all<SparseLhsType>::type Lhs;
+  typedef typename internal::remove_all<DenseRhsType>::type Rhs;
+  typedef typename internal::remove_all<DenseResType>::type Res;
+  typedef typename Lhs::InnerIterator LhsInnerIterator;
+  typedef typename Lhs::Index Index;
+  static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
+  {
+    typedef typename Rhs::ConstRowXpr RhsRow;
+    for(Index outer=0; outer<lhs.outerSize(); ++outer)
+    {
+      // The rhs row is the same for all nonzeros of this outer vector.
+      RhsRow srcRow(rhs.row(outer));
+      LhsInnerIterator it(lhs,outer);
+      while(it)
+      {
+        res.row(it.index()) += (alpha*it.value()) * srcRow;
+        ++it;
+      }
+    }
+  }
+};
+
+/** \internal
+  * Entry point for sparse * dense products: forwards to the specialization of
+  * sparse_time_dense_product_impl selected by that template's default arguments.
+  */
+template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType>
+inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
+{
+  typedef sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType> Kernel;
+  Kernel::run(lhs, rhs, res, alpha);
+}
+
} // end namespace internal
template<typename Lhs, typename Rhs>
@@ -163,27 +259,7 @@ class SparseTimeDenseProduct
template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
{
- typedef typename _LhsNested::InnerIterator LhsInnerIterator;
- enum {
- LhsIsRowMajor = (_LhsNested::Flags&RowMajorBit)==RowMajorBit,
- RhsIsVector = _RhsNested::ColsAtCompileTime==1
- };
- Index j=0;
- for(j=0; j<m_lhs.outerSize(); ++j)
- {
- typename Rhs::Scalar rhs_j = alpha * m_rhs.coeff(LhsIsRowMajor ? 0 : j,0);
- typename Dest::RowXpr dest_j(dest.row(LhsIsRowMajor ? j : 0));
- typename Dest::Scalar tmp(0);
- for(LhsInnerIterator it(m_lhs,j); it ;++it)
- {
- if(LhsIsRowMajor && RhsIsVector) tmp += (it.value()) * m_rhs.coeff(it.index());
- else if(LhsIsRowMajor) dest_j += (alpha*it.value()) * m_rhs.row(it.index());
- else if(RhsIsVector) dest.coeffRef(it.index()) += it.value() * rhs_j;
- else dest.row(it.index()) += (alpha*it.value()) * m_rhs.row(j);
- }
- if(LhsIsRowMajor && RhsIsVector)
- dest.coeffRef(LhsIsRowMajor ? j : 0) = alpha * tmp;
- }
+ internal::sparse_time_dense_product(m_lhs, m_rhs, dest, alpha); // replaces the inline loop: the storage-order / vector case split is now resolved by the internal helper's template specializations
}
private:
@@ -213,11 +289,10 @@ class DenseTimeSparseProduct
template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
{
- typedef typename _RhsNested::InnerIterator RhsInnerIterator;
- enum { RhsIsRowMajor = (_RhsNested::Flags&RowMajorBit)==RowMajorBit };
- for(Index j=0; j<m_rhs.outerSize(); ++j)
- for(RhsInnerIterator i(m_rhs,j); i; ++i)
- dest.col(RhsIsRowMajor ? i.index() : j) += (alpha*i.value()) * m_lhs.col(RhsIsRowMajor ? j : i.index());
+ Transpose<const _LhsNested> lhs_t(m_lhs); // Transpose expression wrapping the dense left operand
+ Transpose<const _RhsNested> rhs_t(m_rhs); // Transpose expression wrapping the sparse right operand
+ Transpose<Dest> dest_t(dest); // Transpose expression wrapping the destination
+ internal::sparse_time_dense_product(rhs_t, lhs_t, dest_t, alpha); // operands swapped: dense*sparse is evaluated as (sparse^T * dense^T), passing dest^T as the destination, reusing the sparse*dense helper
}
private: