diff options
author | Gael Guennebaud <g.gael@free.fr> | 2014-07-01 13:18:56 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2014-07-01 13:18:56 +0200 |
commit | 746d2db6ed5d1bb104757f2170e2018eda524a12 (patch) | |
tree | 5978687be5f9b87db0c4a384d22033887febbb31 /Eigen | |
parent | 441f97b2df8465cb8d5c601e9f1ed324af71491e (diff) |
Implement evaluators for sparse * sparse with auto pruning.
Diffstat (limited to 'Eigen')
-rw-r--r-- | Eigen/SparseCore | 2 | ||||
-rw-r--r-- | Eigen/src/SparseCore/SparseMatrixBase.h | 3 | ||||
-rw-r--r-- | Eigen/src/SparseCore/SparseProduct.h | 30 | ||||
-rw-r--r-- | Eigen/src/SparseCore/SparseSparseProductWithPruning.h | 63 | ||||
-rw-r--r-- | Eigen/src/SparseCore/SparseView.h | 26 |
5 files changed, 120 insertions, 4 deletions
diff --git a/Eigen/SparseCore b/Eigen/SparseCore index 0c91c3b59..f74df3038 100644 --- a/Eigen/SparseCore +++ b/Eigen/SparseCore @@ -50,11 +50,11 @@ struct Sparse {}; #include "src/SparseCore/SparseView.h" #include "src/SparseCore/SparseDiagonalProduct.h" #include "src/SparseCore/ConservativeSparseSparseProduct.h" +#include "src/SparseCore/SparseSparseProductWithPruning.h" #include "src/SparseCore/SparseProduct.h" #ifndef EIGEN_TEST_EVALUATORS #include "src/SparseCore/SparsePermutation.h" #include "src/SparseCore/SparseFuzzy.h" -#include "src/SparseCore/SparseSparseProductWithPruning.h" #include "src/SparseCore/SparseDenseProduct.h" #include "src/SparseCore/SparseTriangularView.h" #include "src/SparseCore/SparseSelfAdjointView.h" diff --git a/Eigen/src/SparseCore/SparseMatrixBase.h b/Eigen/src/SparseCore/SparseMatrixBase.h index cebf3990d..3a81916fb 100644 --- a/Eigen/src/SparseCore/SparseMatrixBase.h +++ b/Eigen/src/SparseCore/SparseMatrixBase.h @@ -394,6 +394,9 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived> { return typename internal::eval<Derived>::type(derived()); } Scalar sum() const; + + inline const SparseView<Derived> + pruned(const Scalar& reference = Scalar(0), const RealScalar& epsilon = NumTraits<Scalar>::dummy_precision()) const; protected: diff --git a/Eigen/src/SparseCore/SparseProduct.h b/Eigen/src/SparseCore/SparseProduct.h index 52c452f92..8b9578836 100644 --- a/Eigen/src/SparseCore/SparseProduct.h +++ b/Eigen/src/SparseCore/SparseProduct.h @@ -242,7 +242,37 @@ struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, SparseSh protected: PlainObject m_result; }; + +template<typename Lhs, typename Rhs, int Options> +struct evaluator<SparseView<Product<Lhs, Rhs, Options> > > + : public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>::type +{ + typedef SparseView<Product<Lhs, Rhs, Options> > XprType; + typedef typename XprType::PlainObject PlainObject; + typedef typename evaluator<PlainObject>::type Base; + + typedef evaluator type; + typedef evaluator nestedType; + + evaluator(const XprType& xpr) + : m_result(xpr.rows(), xpr.cols()) + { + using std::abs; + ::new (static_cast<Base*>(this)) Base(m_result); + typedef typename nested_eval<Lhs,Dynamic>::type LhsNested; + typedef typename nested_eval<Rhs,Dynamic>::type RhsNested; + LhsNested lhsNested(xpr.nestedExpression().lhs()); + RhsNested rhsNested(xpr.nestedExpression().rhs()); + + internal::sparse_sparse_product_with_pruning_selector<typename remove_all<LhsNested>::type, + typename remove_all<RhsNested>::type, PlainObject>::run(lhsNested,rhsNested,m_result, + abs(xpr.reference())*xpr.epsilon()); + } +protected: + PlainObject m_result; +}; + } // end namespace internal #endif // EIGEN_TEST_EVALUATORS diff --git a/Eigen/src/SparseCore/SparseSparseProductWithPruning.h b/Eigen/src/SparseCore/SparseSparseProductWithPruning.h index fcc18f5c9..c33ec6bfd 100644 --- a/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +++ b/Eigen/src/SparseCore/SparseSparseProductWithPruning.h @@ -46,6 +46,11 @@ static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& r res.resize(cols, rows); else res.resize(rows, cols); + + #ifdef EIGEN_TEST_EVALUATORS + typename evaluator<Lhs>::type lhsEval(lhs); + typename evaluator<Rhs>::type rhsEval(rhs); + #endif res.reserve(estimated_nnz_prod); double ratioColRes = double(estimated_nnz_prod)/double(lhs.rows()*rhs.cols()); @@ -56,12 +61,20 @@ static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& r // let's do a more accurate determination of the nnz ratio for the current column j of res tempVector.init(ratioColRes); tempVector.setZero(); +#ifndef EIGEN_TEST_EVALUATORS for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt) +#else + for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt) +#endif { // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index()) tempVector.restart(); Scalar x = rhsIt.value(); +#ifndef EIGEN_TEST_EVALUATORS for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt) +#else + for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, rhsIt.index()); lhsIt; ++lhsIt) +#endif { tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x; } @@ -140,8 +153,58 @@ struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,R } }; +#ifndef EIGEN_TEST_EVALUATORS // NOTE the 2 others cases (col row *) must never occur since they are caught // by ProductReturnType which transforms it to (col col *) by evaluating rhs. +#else +template<typename Lhs, typename Rhs, typename ResultType> +struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor> +{ + typedef typename ResultType::RealScalar RealScalar; + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) + { + typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename Lhs::Index> RowMajorMatrixLhs; + RowMajorMatrixLhs rowLhs(lhs); + sparse_sparse_product_with_pruning_selector<RowMajorMatrixLhs,Rhs,ResultType,RowMajor,RowMajor>(rowLhs,rhs,res,tolerance); + } +}; + +template<typename Lhs, typename Rhs, typename ResultType> +struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor> +{ + typedef typename ResultType::RealScalar RealScalar; + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) + { + typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename Lhs::Index> RowMajorMatrixRhs; + RowMajorMatrixRhs rowRhs(rhs); + sparse_sparse_product_with_pruning_selector<Lhs,RowMajorMatrixRhs,ResultType,RowMajor,RowMajor,RowMajor>(lhs,rowRhs,res,tolerance); + } +}; + +template<typename Lhs, typename Rhs, typename ResultType> +struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor> +{ + typedef typename ResultType::RealScalar RealScalar; + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) + { + typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixRhs; + ColMajorMatrixRhs colRhs(rhs); + internal::sparse_sparse_product_with_pruning_impl<Lhs,ColMajorMatrixRhs,ResultType>(lhs, colRhs, res, tolerance); + } +}; + +template<typename Lhs, typename Rhs, typename ResultType> +struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor> +{ + typedef typename ResultType::RealScalar RealScalar; + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) + { + typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixLhs; + ColMajorMatrixLhs colLhs(lhs); + internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,Rhs,ResultType>(colLhs, rhs, res, tolerance); + } +}; +#endif } // end namespace internal diff --git a/Eigen/src/SparseCore/SparseView.h b/Eigen/src/SparseCore/SparseView.h index 96d0a849c..7bffbb9cd 100644 --- a/Eigen/src/SparseCore/SparseView.h +++ b/Eigen/src/SparseCore/SparseView.h @@ -233,10 +233,30 @@ struct unary_evaluator<SparseView<ArgType>, IndexBased> #endif // EIGEN_TEST_EVALUATORS template<typename Derived> -const SparseView<Derived> MatrixBase<Derived>::sparseView(const Scalar& m_reference, - const typename NumTraits<Scalar>::Real& m_epsilon) const +const SparseView<Derived> MatrixBase<Derived>::sparseView(const Scalar& reference, + const typename NumTraits<Scalar>::Real& epsilon) const { - return SparseView<Derived>(derived(), m_reference, m_epsilon); + return SparseView<Derived>(derived(), reference, epsilon); +} + +/** \returns an expression of \c *this with values smaller than + * \a reference * \a epsilon are removed. + * + * This method is typically used in conjunction with the product of two sparse matrices + * to automatically prune the smallest values as follows: + * \code + * C = (A*B).pruned(); // suppress numerical zeros (exact) + * C = (A*B).pruned(ref); + * C = (A*B).pruned(ref,epsilon); + * \endcode + * where \c ref is a meaningful non zero reference value. + * */ +template<typename Derived> +const SparseView<Derived> +SparseMatrixBase<Derived>::pruned(const Scalar& reference, + const RealScalar& epsilon) const +{ + return SparseView<Derived>(derived(), reference, epsilon); } } // end namespace Eigen |