diff options
author | Gael Guennebaud <g.gael@free.fr> | 2015-10-26 18:20:00 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2015-10-26 18:20:00 +0100 |
commit | e6f8c5c325fca53b53436b6bd8d66749444216bb (patch) | |
tree | 2ae45f2b23242e6c92ce2d2e2cef3c9937840957 | |
parent | a5324a131f3816c8312e27a9dc928b8d56d8cf3b (diff) |
Add support to directly evaluate the product of two sparse matrices within a dense matrix.
-rw-r--r-- | Eigen/src/SparseCore/ConservativeSparseSparseProduct.h | 85 | ||||
-rw-r--r-- | Eigen/src/SparseCore/SparseAssign.h | 8 | ||||
-rw-r--r-- | Eigen/src/SparseCore/SparseMatrixBase.h | 2 | ||||
-rw-r--r-- | Eigen/src/SparseCore/SparseProduct.h | 34 | ||||
-rw-r--r-- | test/sparse_product.cpp | 11 |
5 files changed, 132 insertions, 8 deletions
diff --git a/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h b/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h index a61ceb7cc..0f6835846 100644 --- a/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +++ b/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h @@ -1,7 +1,7 @@ // This file is part of Eigen, a lightweight C++ template library // for linear algebra. // -// Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr> +// Copyright (C) 2008-2015 Gael Guennebaud <gael.guennebaud@inria.fr> // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed @@ -257,6 +257,89 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,R } // end namespace internal + +namespace internal { + +template<typename Lhs, typename Rhs, typename ResultType> +static void sparse_sparse_to_dense_product_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res) +{ + typedef typename remove_all<Lhs>::type::Scalar Scalar; + Index cols = rhs.outerSize(); + eigen_assert(lhs.outerSize() == rhs.innerSize()); + + evaluator<Lhs> lhsEval(lhs); + evaluator<Rhs> rhsEval(rhs); + + for (Index j=0; j<cols; ++j) + { + for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt) + { + Scalar y = rhsIt.value(); + Index k = rhsIt.index(); + for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, k); lhsIt; ++lhsIt) + { + Index i = lhsIt.index(); + Scalar x = lhsIt.value(); + res.coeffRef(i,j) += x * y; + } + } + } +} + + +} // end namespace internal + +namespace internal { + +template<typename Lhs, typename Rhs, typename ResultType, + int LhsStorageOrder = (traits<Lhs>::Flags&RowMajorBit) ? RowMajor : ColMajor, + int RhsStorageOrder = (traits<Rhs>::Flags&RowMajorBit) ? RowMajor : ColMajor> +struct sparse_sparse_to_dense_product_selector; + +template<typename Lhs, typename Rhs, typename ResultType> +struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor> +{ + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) + { + internal::sparse_sparse_to_dense_product_impl<Lhs,Rhs,ResultType>(lhs, rhs, res); + } +}; + +template<typename Lhs, typename Rhs, typename ResultType> +struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor> +{ + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) + { + typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrix; + ColMajorMatrix lhsCol(lhs); + internal::sparse_sparse_to_dense_product_impl<ColMajorMatrix,Rhs,ResultType>(lhsCol, rhs, res); + } +}; + +template<typename Lhs, typename Rhs, typename ResultType> +struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor> +{ + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) + { + typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrix; + ColMajorMatrix rhsCol(rhs); + internal::sparse_sparse_to_dense_product_impl<Lhs,ColMajorMatrix,ResultType>(lhs, rhsCol, res); + } +}; + +template<typename Lhs, typename Rhs, typename ResultType> +struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor> +{ + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) + { + Transpose<ResultType> trRes(res); + internal::sparse_sparse_to_dense_product_impl<Rhs,Lhs,Transpose<ResultType> >(rhs, lhs, trRes); + } +}; + + +} // end namespace internal + } // end namespace Eigen #endif // EIGEN_CONSERVATIVESPARSESPARSEPRODUCT_H diff --git a/Eigen/src/SparseCore/SparseAssign.h b/Eigen/src/SparseCore/SparseAssign.h index e984bbdb3..4b663a59e 100644 --- a/Eigen/src/SparseCore/SparseAssign.h +++ b/Eigen/src/SparseCore/SparseAssign.h @@ -133,8 +133,8 @@ struct Assignment<DstXprType, SrcXprType, Functor, Sparse2Sparse, Scalar> }; // Sparse to Dense assignment -template< typename DstXprType, typename SrcXprType, typename Functor, typename Scalar> -struct Assignment<DstXprType, SrcXprType, Functor, Sparse2Dense, Scalar> +template< typename DstXprType, typename SrcXprType, typename Functor> +struct Assignment<DstXprType, SrcXprType, Functor, Sparse2Dense> { static void run(DstXprType &dst, const SrcXprType &src, const Functor &func) { @@ -149,8 +149,8 @@ struct Assignment<DstXprType, SrcXprType, Functor, Sparse2Dense, Scalar> } }; -template< typename DstXprType, typename SrcXprType, typename Scalar> -struct Assignment<DstXprType, SrcXprType, internal::assign_op<typename DstXprType::Scalar>, Sparse2Dense, Scalar> +template< typename DstXprType, typename SrcXprType> +struct Assignment<DstXprType, SrcXprType, internal::assign_op<typename DstXprType::Scalar>, Sparse2Dense> { static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar> &) { diff --git a/Eigen/src/SparseCore/SparseMatrixBase.h b/Eigen/src/SparseCore/SparseMatrixBase.h index 4e720904e..38eb1c37a 100644 --- a/Eigen/src/SparseCore/SparseMatrixBase.h +++ b/Eigen/src/SparseCore/SparseMatrixBase.h @@ -281,7 +281,7 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived> // sparse * sparse template<typename OtherDerived> - const Product<Derived,OtherDerived> + const Product<Derived,OtherDerived,AliasFreeProduct> operator*(const SparseMatrixBase<OtherDerived> &other) const; // sparse * dense diff --git a/Eigen/src/SparseCore/SparseProduct.h b/Eigen/src/SparseCore/SparseProduct.h index da8919ecc..26680b7a7 100644 --- a/Eigen/src/SparseCore/SparseProduct.h +++ b/Eigen/src/SparseCore/SparseProduct.h @@ -25,10 +25,10 @@ namespace Eigen { * */ template<typename Derived> template<typename OtherDerived> -inline const Product<Derived,OtherDerived> +inline const Product<Derived,OtherDerived,AliasFreeProduct> SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const { - return Product<Derived,OtherDerived>(derived(), other.derived()); + return Product<Derived,OtherDerived,AliasFreeProduct>(derived(), other.derived()); } namespace internal { @@ -61,6 +61,36 @@ struct generic_product_impl<Lhs, Rhs, SparseTriangularShape, SparseShape, Produc : public generic_product_impl<Lhs, Rhs, SparseShape, SparseShape, ProductType> {}; +// Dense = sparse * sparse +template< typename DstXprType, typename Lhs, typename Rhs, int Options/*, typename Scalar*/> +struct Assignment<DstXprType, Product<Lhs,Rhs,Options>, internal::assign_op<typename DstXprType::Scalar>, Sparse2Dense/*, + typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct),Scalar>::type*/> +{ + typedef Product<Lhs,Rhs,Options> SrcXprType; + static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar> &) + { + dst.setZero(); + dst += src; + } +}; + +// Dense += sparse * sparse +template< typename DstXprType, typename Lhs, typename Rhs, int Options> +struct Assignment<DstXprType, Product<Lhs,Rhs,Options>, internal::add_assign_op<typename DstXprType::Scalar>, Sparse2Dense/*, + typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct),Scalar>::type*/> +{ + typedef Product<Lhs,Rhs,Options> SrcXprType; + static void run(DstXprType &dst, const SrcXprType &src, const internal::add_assign_op<typename DstXprType::Scalar> &) + { + typedef typename nested_eval<Lhs,Dynamic>::type LhsNested; + typedef typename nested_eval<Rhs,Dynamic>::type RhsNested; + LhsNested lhsNested(src.lhs()); + RhsNested rhsNested(src.rhs()); + internal::sparse_sparse_to_dense_product_selector<typename remove_all<LhsNested>::type, + typename remove_all<RhsNested>::type, DstXprType>::run(lhsNested,rhsNested,dst); + } +}; + template<typename Lhs, typename Rhs, int Options> struct evaluator<SparseView<Product<Lhs, Rhs, Options> > > : public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject> diff --git a/test/sparse_product.cpp b/test/sparse_product.cpp index f1e5b8e4c..8c83f08d7 100644 --- a/test/sparse_product.cpp +++ b/test/sparse_product.cpp @@ -76,6 +76,17 @@ template<typename SparseMatrixType> void sparse_product() VERIFY_IS_APPROX(m4=(m2t.transpose()*m3t.transpose()).pruned(0), refMat4=refMat2t.transpose()*refMat3t.transpose()); VERIFY_IS_APPROX(m4=(m2*m3t.transpose()).pruned(0), refMat4=refMat2*refMat3t.transpose()); + // dense ?= sparse * sparse + VERIFY_IS_APPROX(dm4 =m2*m3, refMat4 =refMat2*refMat3); + VERIFY_IS_APPROX(dm4+=m2*m3, refMat4+=refMat2*refMat3); + VERIFY_IS_APPROX(dm4 =m2t.transpose()*m3, refMat4 =refMat2t.transpose()*refMat3); + VERIFY_IS_APPROX(dm4+=m2t.transpose()*m3, refMat4+=refMat2t.transpose()*refMat3); + VERIFY_IS_APPROX(dm4 =m2t.transpose()*m3t.transpose(), refMat4 =refMat2t.transpose()*refMat3t.transpose()); + VERIFY_IS_APPROX(dm4+=m2t.transpose()*m3t.transpose(), refMat4+=refMat2t.transpose()*refMat3t.transpose()); + VERIFY_IS_APPROX(dm4 =m2*m3t.transpose(), refMat4 =refMat2*refMat3t.transpose()); + VERIFY_IS_APPROX(dm4+=m2*m3t.transpose(), refMat4+=refMat2*refMat3t.transpose()); + VERIFY_IS_APPROX(dm4 = m2*m3*s1, refMat4 = refMat2*refMat3*s1); + // test aliasing m4 = m2; refMat4 = refMat2; VERIFY_IS_APPROX(m4=m4*m3, refMat4=refMat4*refMat3); |