aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2015-10-26 18:20:00 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2015-10-26 18:20:00 +0100
commite6f8c5c325fca53b53436b6bd8d66749444216bb (patch)
tree2ae45f2b23242e6c92ce2d2e2cef3c9937840957
parenta5324a131f3816c8312e27a9dc928b8d56d8cf3b (diff)
Add support to directly evaluate the product of two sparse matrices within a dense matrix.
-rw-r--r--Eigen/src/SparseCore/ConservativeSparseSparseProduct.h85
-rw-r--r--Eigen/src/SparseCore/SparseAssign.h8
-rw-r--r--Eigen/src/SparseCore/SparseMatrixBase.h2
-rw-r--r--Eigen/src/SparseCore/SparseProduct.h34
-rw-r--r--test/sparse_product.cpp11
5 files changed, 132 insertions, 8 deletions
diff --git a/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h b/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h
index a61ceb7cc..0f6835846 100644
--- a/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h
+++ b/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h
@@ -1,7 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
-// Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
+// Copyright (C) 2008-2015 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@@ -257,6 +257,89 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,R
} // end namespace internal
+
+namespace internal {
+
+template<typename Lhs, typename Rhs, typename ResultType>
+static void sparse_sparse_to_dense_product_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res)
+{
+ typedef typename remove_all<Lhs>::type::Scalar Scalar;
+ Index cols = rhs.outerSize();
+ eigen_assert(lhs.outerSize() == rhs.innerSize());
+
+ evaluator<Lhs> lhsEval(lhs);
+ evaluator<Rhs> rhsEval(rhs);
+
+ for (Index j=0; j<cols; ++j)
+ {
+ for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt)
+ {
+ Scalar y = rhsIt.value();
+ Index k = rhsIt.index();
+ for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, k); lhsIt; ++lhsIt)
+ {
+ Index i = lhsIt.index();
+ Scalar x = lhsIt.value();
+ res.coeffRef(i,j) += x * y;
+ }
+ }
+ }
+}
+
+
+} // end namespace internal
+
+namespace internal {
+
+template<typename Lhs, typename Rhs, typename ResultType,
+ int LhsStorageOrder = (traits<Lhs>::Flags&RowMajorBit) ? RowMajor : ColMajor,
+ int RhsStorageOrder = (traits<Rhs>::Flags&RowMajorBit) ? RowMajor : ColMajor>
+struct sparse_sparse_to_dense_product_selector;
+
+template<typename Lhs, typename Rhs, typename ResultType>
+struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor>
+{
+ static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
+ {
+ internal::sparse_sparse_to_dense_product_impl<Lhs,Rhs,ResultType>(lhs, rhs, res);
+ }
+};
+
+template<typename Lhs, typename Rhs, typename ResultType>
+struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor>
+{
+ static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
+ {
+ typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrix;
+ ColMajorMatrix lhsCol(lhs);
+ internal::sparse_sparse_to_dense_product_impl<ColMajorMatrix,Rhs,ResultType>(lhsCol, rhs, res);
+ }
+};
+
+template<typename Lhs, typename Rhs, typename ResultType>
+struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor>
+{
+ static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
+ {
+ typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrix;
+ ColMajorMatrix rhsCol(rhs);
+ internal::sparse_sparse_to_dense_product_impl<Lhs,ColMajorMatrix,ResultType>(lhs, rhsCol, res);
+ }
+};
+
+template<typename Lhs, typename Rhs, typename ResultType>
+struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor>
+{
+ static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
+ {
+ Transpose<ResultType> trRes(res);
+ internal::sparse_sparse_to_dense_product_impl<Rhs,Lhs,Transpose<ResultType> >(rhs, lhs, trRes);
+ }
+};
+
+
+} // end namespace internal
+
} // end namespace Eigen
#endif // EIGEN_CONSERVATIVESPARSESPARSEPRODUCT_H
diff --git a/Eigen/src/SparseCore/SparseAssign.h b/Eigen/src/SparseCore/SparseAssign.h
index e984bbdb3..4b663a59e 100644
--- a/Eigen/src/SparseCore/SparseAssign.h
+++ b/Eigen/src/SparseCore/SparseAssign.h
@@ -133,8 +133,8 @@ struct Assignment<DstXprType, SrcXprType, Functor, Sparse2Sparse, Scalar>
};
// Sparse to Dense assignment
-template< typename DstXprType, typename SrcXprType, typename Functor, typename Scalar>
-struct Assignment<DstXprType, SrcXprType, Functor, Sparse2Dense, Scalar>
+template< typename DstXprType, typename SrcXprType, typename Functor>
+struct Assignment<DstXprType, SrcXprType, Functor, Sparse2Dense>
{
static void run(DstXprType &dst, const SrcXprType &src, const Functor &func)
{
@@ -149,8 +149,8 @@ struct Assignment<DstXprType, SrcXprType, Functor, Sparse2Dense, Scalar>
}
};
-template< typename DstXprType, typename SrcXprType, typename Scalar>
-struct Assignment<DstXprType, SrcXprType, internal::assign_op<typename DstXprType::Scalar>, Sparse2Dense, Scalar>
+template< typename DstXprType, typename SrcXprType>
+struct Assignment<DstXprType, SrcXprType, internal::assign_op<typename DstXprType::Scalar>, Sparse2Dense>
{
static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar> &)
{
diff --git a/Eigen/src/SparseCore/SparseMatrixBase.h b/Eigen/src/SparseCore/SparseMatrixBase.h
index 4e720904e..38eb1c37a 100644
--- a/Eigen/src/SparseCore/SparseMatrixBase.h
+++ b/Eigen/src/SparseCore/SparseMatrixBase.h
@@ -281,7 +281,7 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
// sparse * sparse
template<typename OtherDerived>
- const Product<Derived,OtherDerived>
+ const Product<Derived,OtherDerived,AliasFreeProduct>
operator*(const SparseMatrixBase<OtherDerived> &other) const;
// sparse * dense
diff --git a/Eigen/src/SparseCore/SparseProduct.h b/Eigen/src/SparseCore/SparseProduct.h
index da8919ecc..26680b7a7 100644
--- a/Eigen/src/SparseCore/SparseProduct.h
+++ b/Eigen/src/SparseCore/SparseProduct.h
@@ -25,10 +25,10 @@ namespace Eigen {
* */
template<typename Derived>
template<typename OtherDerived>
-inline const Product<Derived,OtherDerived>
+inline const Product<Derived,OtherDerived,AliasFreeProduct>
SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const
{
- return Product<Derived,OtherDerived>(derived(), other.derived());
+ return Product<Derived,OtherDerived,AliasFreeProduct>(derived(), other.derived());
}
namespace internal {
@@ -61,6 +61,36 @@ struct generic_product_impl<Lhs, Rhs, SparseTriangularShape, SparseShape, Produc
: public generic_product_impl<Lhs, Rhs, SparseShape, SparseShape, ProductType>
{};
+// Dense = sparse * sparse
+template< typename DstXprType, typename Lhs, typename Rhs, int Options/*, typename Scalar*/>
+struct Assignment<DstXprType, Product<Lhs,Rhs,Options>, internal::assign_op<typename DstXprType::Scalar>, Sparse2Dense/*,
+ typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct),Scalar>::type*/>
+{
+ typedef Product<Lhs,Rhs,Options> SrcXprType;
+ static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar> &)
+ {
+ dst.setZero();
+ dst += src;
+ }
+};
+
+// Dense += sparse * sparse
+template< typename DstXprType, typename Lhs, typename Rhs, int Options>
+struct Assignment<DstXprType, Product<Lhs,Rhs,Options>, internal::add_assign_op<typename DstXprType::Scalar>, Sparse2Dense/*,
+ typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct),Scalar>::type*/>
+{
+ typedef Product<Lhs,Rhs,Options> SrcXprType;
+ static void run(DstXprType &dst, const SrcXprType &src, const internal::add_assign_op<typename DstXprType::Scalar> &)
+ {
+ typedef typename nested_eval<Lhs,Dynamic>::type LhsNested;
+ typedef typename nested_eval<Rhs,Dynamic>::type RhsNested;
+ LhsNested lhsNested(src.lhs());
+ RhsNested rhsNested(src.rhs());
+ internal::sparse_sparse_to_dense_product_selector<typename remove_all<LhsNested>::type,
+ typename remove_all<RhsNested>::type, DstXprType>::run(lhsNested,rhsNested,dst);
+ }
+};
+
template<typename Lhs, typename Rhs, int Options>
struct evaluator<SparseView<Product<Lhs, Rhs, Options> > >
: public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>
diff --git a/test/sparse_product.cpp b/test/sparse_product.cpp
index f1e5b8e4c..8c83f08d7 100644
--- a/test/sparse_product.cpp
+++ b/test/sparse_product.cpp
@@ -76,6 +76,17 @@ template<typename SparseMatrixType> void sparse_product()
VERIFY_IS_APPROX(m4=(m2t.transpose()*m3t.transpose()).pruned(0), refMat4=refMat2t.transpose()*refMat3t.transpose());
VERIFY_IS_APPROX(m4=(m2*m3t.transpose()).pruned(0), refMat4=refMat2*refMat3t.transpose());
+ // dense ?= sparse * sparse
+ VERIFY_IS_APPROX(dm4 =m2*m3, refMat4 =refMat2*refMat3);
+ VERIFY_IS_APPROX(dm4+=m2*m3, refMat4+=refMat2*refMat3);
+ VERIFY_IS_APPROX(dm4 =m2t.transpose()*m3, refMat4 =refMat2t.transpose()*refMat3);
+ VERIFY_IS_APPROX(dm4+=m2t.transpose()*m3, refMat4+=refMat2t.transpose()*refMat3);
+ VERIFY_IS_APPROX(dm4 =m2t.transpose()*m3t.transpose(), refMat4 =refMat2t.transpose()*refMat3t.transpose());
+ VERIFY_IS_APPROX(dm4+=m2t.transpose()*m3t.transpose(), refMat4+=refMat2t.transpose()*refMat3t.transpose());
+ VERIFY_IS_APPROX(dm4 =m2*m3t.transpose(), refMat4 =refMat2*refMat3t.transpose());
+ VERIFY_IS_APPROX(dm4+=m2*m3t.transpose(), refMat4+=refMat2*refMat3t.transpose());
+ VERIFY_IS_APPROX(dm4 = m2*m3*s1, refMat4 = refMat2*refMat3*s1);
+
// test aliasing
m4 = m2; refMat4 = refMat2;
VERIFY_IS_APPROX(m4=m4*m3, refMat4=refMat4*refMat3);