diff options
-rw-r--r-- | Eigen/SparseCore | 4 | ||||
-rw-r--r-- | Eigen/src/SparseCore/ConservativeSparseSparseProduct.h | 13 | ||||
-rw-r--r-- | Eigen/src/SparseCore/SparseMatrixBase.h | 11 | ||||
-rw-r--r-- | Eigen/src/SparseCore/SparseProduct.h | 64 |
4 files changed, 88 insertions, 4 deletions
diff --git a/Eigen/SparseCore b/Eigen/SparseCore index b338950ca..0c91c3b59 100644 --- a/Eigen/SparseCore +++ b/Eigen/SparseCore @@ -49,12 +49,12 @@ struct Sparse {}; #include "src/SparseCore/SparseRedux.h" #include "src/SparseCore/SparseView.h" #include "src/SparseCore/SparseDiagonalProduct.h" +#include "src/SparseCore/ConservativeSparseSparseProduct.h" +#include "src/SparseCore/SparseProduct.h" #ifndef EIGEN_TEST_EVALUATORS #include "src/SparseCore/SparsePermutation.h" #include "src/SparseCore/SparseFuzzy.h" -#include "src/SparseCore/ConservativeSparseSparseProduct.h" #include "src/SparseCore/SparseSparseProductWithPruning.h" -#include "src/SparseCore/SparseProduct.h" #include "src/SparseCore/SparseDenseProduct.h" #include "src/SparseCore/SparseTriangularView.h" #include "src/SparseCore/SparseSelfAdjointView.h" diff --git a/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h b/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h index 5c320e2d2..193d71ca9 100644 --- a/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +++ b/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h @@ -36,6 +36,11 @@ static void conservative_sparse_sparse_product_impl(const Lhs& lhs, const Rhs& r // per column of the lhs. // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs) Index estimated_nnz_prod = lhs.nonZeros() + rhs.nonZeros(); + +#ifdef EIGEN_TEST_EVALUATORS + typename evaluator<Lhs>::type lhsEval(lhs); + typename evaluator<Rhs>::type rhsEval(rhs); +#endif res.setZero(); res.reserve(Index(estimated_nnz_prod)); @@ -45,11 +50,19 @@ static void conservative_sparse_sparse_product_impl(const Lhs& lhs, const Rhs& r res.startVec(j); Index nnz = 0; +#ifndef EIGEN_TEST_EVALUATORS for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt) +#else + for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt) +#endif { Scalar y = rhsIt.value(); Index k = rhsIt.index(); +#ifndef EIGEN_TEST_EVALUATORS for (typename Lhs::InnerIterator lhsIt(lhs, k); lhsIt; ++lhsIt) +#else + for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, k); lhsIt; ++lhsIt) +#endif { Index i = lhsIt.index(); Scalar x = lhsIt.value(); diff --git a/Eigen/src/SparseCore/SparseMatrixBase.h b/Eigen/src/SparseCore/SparseMatrixBase.h index c71244d3e..cebf3990d 100644 --- a/Eigen/src/SparseCore/SparseMatrixBase.h +++ b/Eigen/src/SparseCore/SparseMatrixBase.h @@ -190,8 +190,10 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived> public: +#ifndef EIGEN_TEST_EVALUATORS template<typename Lhs, typename Rhs> inline Derived& operator=(const SparseSparseProduct<Lhs,Rhs>& product); +#endif friend std::ostream & operator << (std::ostream & s, const SparseMatrixBase& m) { @@ -264,12 +266,12 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived> EIGEN_STRONG_INLINE const EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE cwiseProduct(const MatrixBase<OtherDerived> &other) const; +#ifndef EIGEN_TEST_EVALUATORS // sparse * sparse template<typename OtherDerived> const typename SparseSparseProductReturnType<Derived,OtherDerived>::Type operator*(const SparseMatrixBase<OtherDerived> &other) const; - -#ifndef EIGEN_TEST_EVALUATORS + // sparse * diagonal template<typename OtherDerived> const SparseDiagonalProduct<Derived,OtherDerived> @@ -292,6 +294,11 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived> const Product<OtherDerived,Derived> operator*(const DiagonalBase<OtherDerived> &lhs, const SparseMatrixBase& rhs) { return Product<OtherDerived,Derived>(lhs.derived(), rhs.derived()); } + + // sparse * sparse + template<typename OtherDerived> + const Product<Derived,OtherDerived> + operator*(const SparseMatrixBase<OtherDerived> &other) const; #endif // EIGEN_TEST_EVALUATORS /** dense * sparse (return a dense object unless it is an outer product) */ diff --git a/Eigen/src/SparseCore/SparseProduct.h b/Eigen/src/SparseCore/SparseProduct.h index cf7663070..52c452f92 100644 --- a/Eigen/src/SparseCore/SparseProduct.h +++ b/Eigen/src/SparseCore/SparseProduct.h @@ -12,6 +12,8 @@ namespace Eigen { +#ifndef EIGEN_TEST_EVALUATORS + template<typename Lhs, typename Rhs> struct SparseSparseProductReturnType { @@ -183,6 +185,68 @@ SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other return typename SparseSparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived()); } +#else // EIGEN_TEST_EVALUATORS + + +/** \returns an expression of the product of two sparse matrices. + * By default a conservative product preserving the symbolic non zeros is performed. + * The automatic pruning of the small values can be achieved by calling the pruned() function + * in which case a totally different product algorithm is employed: + * \code + * C = (A*B).pruned(); // supress numerical zeros (exact) + * C = (A*B).pruned(ref); + * C = (A*B).pruned(ref,epsilon); + * \endcode + * where \c ref is a meaningful non zero reference value. + * */ +template<typename Derived> +template<typename OtherDerived> +inline const Product<Derived,OtherDerived> +SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const +{ + return Product<Derived,OtherDerived>(derived(), other.derived()); +} + +namespace internal { + +template<typename Lhs, typename Rhs, int ProductType> +struct generic_product_impl<Lhs, Rhs, SparseShape, SparseShape, ProductType> +{ + template<typename Dest> + static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) + { + typedef typename nested_eval<Lhs,Dynamic>::type LhsNested; + typedef typename nested_eval<Rhs,Dynamic>::type RhsNested; + LhsNested lhsNested(lhs); + RhsNested rhsNested(rhs); + internal::conservative_sparse_sparse_product_selector<typename remove_all<LhsNested>::type, + typename remove_all<RhsNested>::type, Dest>::run(lhsNested,rhsNested,dst); + } +}; + +template<typename Lhs, typename Rhs, int ProductTag> +struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, SparseShape, SparseShape, typename Lhs::Scalar, typename Rhs::Scalar> + : public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>::type +{ + typedef Product<Lhs, Rhs, DefaultProduct> XprType; + typedef typename XprType::PlainObject PlainObject; + typedef typename evaluator<PlainObject>::type Base; + + product_evaluator(const XprType& xpr) + : m_result(xpr.rows(), xpr.cols()) + { + ::new (static_cast<Base*>(this)) Base(m_result); + generic_product_impl<Lhs, Rhs, SparseShape, SparseShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs()); + } + +protected: + PlainObject m_result; +}; + +} // end namespace internal + +#endif // EIGEN_TEST_EVALUATORS + } // end namespace Eigen #endif // EIGEN_SPARSEPRODUCT_H |