aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Sparse/SparseProduct.h
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2009-01-14 17:41:55 +0000
committerGravatar Gael Guennebaud <g.gael@free.fr>2009-01-14 17:41:55 +0000
commit0b606dcccd58fef640f5037088005dcdd1d3487e (patch)
tree58cadd4d9f22a517858c780ef0a12057b89e8636 /Eigen/src/Sparse/SparseProduct.h
parentc4c70669d165afefe0c68e7bb194ee81b9fba0b5 (diff)
Add support for sparse * dense and dense * sparse matrix/vector products
Diffstat (limited to 'Eigen/src/Sparse/SparseProduct.h')
-rw-r--r--Eigen/src/Sparse/SparseProduct.h130
1 files changed, 98 insertions, 32 deletions
diff --git a/Eigen/src/Sparse/SparseProduct.h b/Eigen/src/Sparse/SparseProduct.h
index b4ba2ee6f..29f5208fa 100644
--- a/Eigen/src/Sparse/SparseProduct.h
+++ b/Eigen/src/Sparse/SparseProduct.h
@@ -25,9 +25,29 @@
#ifndef EIGEN_SPARSEPRODUCT_H
#define EIGEN_SPARSEPRODUCT_H
+template<typename Lhs, typename Rhs> struct ei_sparse_product_mode
+{
+ enum {
+
+ value = (Rhs::Flags&Lhs::Flags&SparseBit)==SparseBit
+ ? SparseTimeSparseProduct
+ : (Lhs::Flags&SparseBit)==SparseBit
+ ? SparseTimeDenseProduct
+ : DenseTimeSparseProduct };
+};
+
+template<typename Lhs, typename Rhs, int ProductMode>
+struct SparseProductReturnType
+{
+ typedef const typename ei_nested<Lhs,Rhs::RowsAtCompileTime>::type LhsNested;
+ typedef const typename ei_nested<Rhs,Lhs::RowsAtCompileTime>::type RhsNested;
+
+ typedef SparseProduct<LhsNested, RhsNested, ProductMode> Type;
+};
+
// sparse product return type specialization
template<typename Lhs, typename Rhs>
-struct SparseProductReturnType
+struct SparseProductReturnType<Lhs,Rhs,SparseTimeSparseProduct>
{
typedef typename ei_traits<Lhs>::Scalar Scalar;
enum {
@@ -47,11 +67,11 @@ struct SparseProductReturnType
SparseMatrix<Scalar,0>,
const typename ei_nested<Rhs,Lhs::RowsAtCompileTime>::type>::ret RhsNested;
- typedef SparseProduct<LhsNested, RhsNested> Type;
+ typedef SparseProduct<LhsNested, RhsNested, SparseTimeSparseProduct> Type;
};
-template<typename LhsNested, typename RhsNested>
-struct ei_traits<SparseProduct<LhsNested, RhsNested> >
+template<typename LhsNested, typename RhsNested, int ProductMode>
+struct ei_traits<SparseProduct<LhsNested, RhsNested, ProductMode> >
{
// clean the nested types:
typedef typename ei_cleantype<LhsNested>::type _LhsNested;
@@ -71,12 +91,13 @@ struct ei_traits<SparseProduct<LhsNested, RhsNested> >
MaxRowsAtCompileTime = _LhsNested::MaxRowsAtCompileTime,
MaxColsAtCompileTime = _RhsNested::MaxColsAtCompileTime,
- LhsRowMajor = LhsFlags & RowMajorBit,
- RhsRowMajor = RhsFlags & RowMajorBit,
+// LhsIsRowMajor = (LhsFlags & RowMajorBit)==RowMajorBit,
+// RhsIsRowMajor = (RhsFlags & RowMajorBit)==RowMajorBit,
EvalToRowMajor = (RhsFlags & LhsFlags & RowMajorBit),
+ ResultIsSparse = ProductMode==SparseTimeSparseProduct,
- RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit),
+ RemovedBits = ~( (EvalToRowMajor ? 0 : RowMajorBit) | (ResultIsSparse ? 0 : SparseBit) ),
Flags = (int(LhsFlags | RhsFlags) & HereditaryBits & RemovedBits)
| EvalBeforeAssigningBit
@@ -84,11 +105,14 @@ struct ei_traits<SparseProduct<LhsNested, RhsNested> >
CoeffReadCost = Dynamic
};
+
+ typedef typename ei_meta_if<ResultIsSparse,
+ SparseMatrixBase<SparseProduct<LhsNested, RhsNested, ProductMode> >,
+ MatrixBase<SparseProduct<LhsNested, RhsNested, ProductMode> > >::ret Base;
};
-template<typename LhsNested, typename RhsNested>
-class SparseProduct : ei_no_assignment_operator,
- public SparseMatrixBase<SparseProduct<LhsNested, RhsNested> >
+template<typename LhsNested, typename RhsNested, int ProductMode>
+class SparseProduct : ei_no_assignment_operator, public ei_traits<SparseProduct<LhsNested, RhsNested, ProductMode> >::Base
{
public:
@@ -102,17 +126,33 @@ class SparseProduct : ei_no_assignment_operator,
public:
template<typename Lhs, typename Rhs>
- inline SparseProduct(const Lhs& lhs, const Rhs& rhs)
+ EIGEN_STRONG_INLINE SparseProduct(const Lhs& lhs, const Rhs& rhs)
: m_lhs(lhs), m_rhs(rhs)
{
ei_assert(lhs.cols() == rhs.rows());
+
+ enum {
+ ProductIsValid = _LhsNested::ColsAtCompileTime==Dynamic
+ || _RhsNested::RowsAtCompileTime==Dynamic
+ || int(_LhsNested::ColsAtCompileTime)==int(_RhsNested::RowsAtCompileTime),
+ AreVectors = _LhsNested::IsVectorAtCompileTime && _RhsNested::IsVectorAtCompileTime,
+ SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(_LhsNested,_RhsNested)
+ };
+ // note to the lost user:
+ // * for a dot product use: v1.dot(v2)
+ // * for a coeff-wise product use: v1.cwise()*v2
+ EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes),
+ INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
+ EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
+ INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
+ EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)
}
- inline int rows() const { return m_lhs.rows(); }
- inline int cols() const { return m_rhs.cols(); }
+ EIGEN_STRONG_INLINE int rows() const { return m_lhs.rows(); }
+ EIGEN_STRONG_INLINE int cols() const { return m_rhs.cols(); }
- const _LhsNested& lhs() const { return m_lhs; }
- const _LhsNested& rhs() const { return m_rhs; }
+ EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; }
+ EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; }
protected:
LhsNested m_lhs;
@@ -240,9 +280,10 @@ struct ei_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
// return derived();
// }
+// sparse = sparse * sparse
template<typename Derived>
template<typename Lhs, typename Rhs>
-inline Derived& SparseMatrixBase<Derived>::operator=(const SparseProduct<Lhs,Rhs>& product)
+inline Derived& SparseMatrixBase<Derived>::operator=(const SparseProduct<Lhs,Rhs,SparseTimeSparseProduct>& product)
{
// std::cout << "sparse product to sparse\n";
ei_sparse_product_selector<
@@ -252,26 +293,51 @@ inline Derived& SparseMatrixBase<Derived>::operator=(const SparseProduct<Lhs,Rhs
return derived();
}
+// dense = sparse * dense
+template<typename Derived>
+template<typename Lhs, typename Rhs>
+Derived& MatrixBase<Derived>::lazyAssign(const SparseProduct<Lhs,Rhs,SparseTimeDenseProduct>& product)
+{
+ typedef typename ei_cleantype<Lhs>::type _Lhs;
+ typedef typename _Lhs::InnerIterator LhsInnerIterator;
+ enum { LhsIsRowMajor = (_Lhs::Flags&RowMajorBit)==RowMajorBit };
+ derived().setZero();
+ for (int j=0; j<product.lhs().outerSize(); ++j)
+ for (LhsInnerIterator i(product.lhs(),j); i; ++i)
+ derived().row(LhsIsRowMajor ? j : i.index()) += i.value() * product.rhs().row(LhsIsRowMajor ? i.index() : j);
+ return derived();
+}
+
+// dense = dense * sparse
+template<typename Derived>
+template<typename Lhs, typename Rhs>
+Derived& MatrixBase<Derived>::lazyAssign(const SparseProduct<Lhs,Rhs,DenseTimeSparseProduct>& product)
+{
+ typedef typename ei_cleantype<Rhs>::type _Rhs;
+ typedef typename _Rhs::InnerIterator RhsInnerIterator;
+ enum { RhsIsRowMajor = (_Rhs::Flags&RowMajorBit)==RowMajorBit };
+ derived().setZero();
+ for (int j=0; j<product.rhs().outerSize(); ++j)
+ for (RhsInnerIterator i(product.rhs(),j); i; ++i)
+ derived().col(RhsIsRowMajor ? i.index() : j) += i.value() * product.lhs().col(RhsIsRowMajor ? j : i.index());
+ return derived();
+}
+
+// sparse * sparse
template<typename Derived>
template<typename OtherDerived>
-inline const typename SparseProductReturnType<Derived,OtherDerived>::Type
+EIGEN_STRONG_INLINE const typename SparseProductReturnType<Derived,OtherDerived>::Type
SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const
{
- enum {
- ProductIsValid = Derived::ColsAtCompileTime==Dynamic
- || OtherDerived::RowsAtCompileTime==Dynamic
- || int(Derived::ColsAtCompileTime)==int(OtherDerived::RowsAtCompileTime),
- AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime,
- SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(Derived,OtherDerived)
- };
- // note to the lost user:
- // * for a dot product use: v1.dot(v2)
- // * for a coeff-wise product use: v1.cwise()*v2
- EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes),
- INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
- EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
- INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
- EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)
+ return typename SparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived());
+}
+
+// sparse * dense
+template<typename Derived>
+template<typename OtherDerived>
+EIGEN_STRONG_INLINE const typename SparseProductReturnType<Derived,OtherDerived>::Type
+SparseMatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const
+{
return typename SparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived());
}