aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2014-07-01 13:18:56 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2014-07-01 13:18:56 +0200
commit746d2db6ed5d1bb104757f2170e2018eda524a12 (patch)
tree5978687be5f9b87db0c4a384d22033887febbb31 /Eigen
parent441f97b2df8465cb8d5c601e9f1ed324af71491e (diff)
Implement evaluators for sparse * sparse with auto pruning.
Diffstat (limited to 'Eigen')
-rw-r--r--Eigen/SparseCore2
-rw-r--r--Eigen/src/SparseCore/SparseMatrixBase.h3
-rw-r--r--Eigen/src/SparseCore/SparseProduct.h30
-rw-r--r--Eigen/src/SparseCore/SparseSparseProductWithPruning.h63
-rw-r--r--Eigen/src/SparseCore/SparseView.h26
5 files changed, 120 insertions, 4 deletions
diff --git a/Eigen/SparseCore b/Eigen/SparseCore
index 0c91c3b59..f74df3038 100644
--- a/Eigen/SparseCore
+++ b/Eigen/SparseCore
@@ -50,11 +50,11 @@ struct Sparse {};
#include "src/SparseCore/SparseView.h"
#include "src/SparseCore/SparseDiagonalProduct.h"
#include "src/SparseCore/ConservativeSparseSparseProduct.h"
+#include "src/SparseCore/SparseSparseProductWithPruning.h"
#include "src/SparseCore/SparseProduct.h"
#ifndef EIGEN_TEST_EVALUATORS
#include "src/SparseCore/SparsePermutation.h"
#include "src/SparseCore/SparseFuzzy.h"
-#include "src/SparseCore/SparseSparseProductWithPruning.h"
#include "src/SparseCore/SparseDenseProduct.h"
#include "src/SparseCore/SparseTriangularView.h"
#include "src/SparseCore/SparseSelfAdjointView.h"
diff --git a/Eigen/src/SparseCore/SparseMatrixBase.h b/Eigen/src/SparseCore/SparseMatrixBase.h
index cebf3990d..3a81916fb 100644
--- a/Eigen/src/SparseCore/SparseMatrixBase.h
+++ b/Eigen/src/SparseCore/SparseMatrixBase.h
@@ -394,6 +394,9 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
{ return typename internal::eval<Derived>::type(derived()); }
Scalar sum() const;
+
+ inline const SparseView<Derived>
+ pruned(const Scalar& reference = Scalar(0), const RealScalar& epsilon = NumTraits<Scalar>::dummy_precision()) const;
protected:
diff --git a/Eigen/src/SparseCore/SparseProduct.h b/Eigen/src/SparseCore/SparseProduct.h
index 52c452f92..8b9578836 100644
--- a/Eigen/src/SparseCore/SparseProduct.h
+++ b/Eigen/src/SparseCore/SparseProduct.h
@@ -242,7 +242,37 @@ struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, SparseSh
protected:
PlainObject m_result;
};
+
+template<typename Lhs, typename Rhs, int Options>
+struct evaluator<SparseView<Product<Lhs, Rhs, Options> > >
+ : public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>::type
+{
+ typedef SparseView<Product<Lhs, Rhs, Options> > XprType;
+ typedef typename XprType::PlainObject PlainObject;
+ typedef typename evaluator<PlainObject>::type Base;
+
+ typedef evaluator type;
+ typedef evaluator nestedType;
+
+ evaluator(const XprType& xpr)
+ : m_result(xpr.rows(), xpr.cols())
+ {
+ using std::abs;
+ ::new (static_cast<Base*>(this)) Base(m_result);
+ typedef typename nested_eval<Lhs,Dynamic>::type LhsNested;
+ typedef typename nested_eval<Rhs,Dynamic>::type RhsNested;
+ LhsNested lhsNested(xpr.nestedExpression().lhs());
+ RhsNested rhsNested(xpr.nestedExpression().rhs());
+
+ internal::sparse_sparse_product_with_pruning_selector<typename remove_all<LhsNested>::type,
+ typename remove_all<RhsNested>::type, PlainObject>::run(lhsNested,rhsNested,m_result,
+ abs(xpr.reference())*xpr.epsilon());
+ }
+protected:
+ PlainObject m_result;
+};
+
} // end namespace internal
#endif // EIGEN_TEST_EVALUATORS
diff --git a/Eigen/src/SparseCore/SparseSparseProductWithPruning.h b/Eigen/src/SparseCore/SparseSparseProductWithPruning.h
index fcc18f5c9..c33ec6bfd 100644
--- a/Eigen/src/SparseCore/SparseSparseProductWithPruning.h
+++ b/Eigen/src/SparseCore/SparseSparseProductWithPruning.h
@@ -46,6 +46,11 @@ static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& r
res.resize(cols, rows);
else
res.resize(rows, cols);
+
+ #ifdef EIGEN_TEST_EVALUATORS
+ typename evaluator<Lhs>::type lhsEval(lhs);
+ typename evaluator<Rhs>::type rhsEval(rhs);
+ #endif
res.reserve(estimated_nnz_prod);
double ratioColRes = double(estimated_nnz_prod)/double(lhs.rows()*rhs.cols());
@@ -56,12 +61,20 @@ static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& r
// let's do a more accurate determination of the nnz ratio for the current column j of res
tempVector.init(ratioColRes);
tempVector.setZero();
+#ifndef EIGEN_TEST_EVALUATORS
for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
+#else
+ for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt)
+#endif
{
// FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
tempVector.restart();
Scalar x = rhsIt.value();
+#ifndef EIGEN_TEST_EVALUATORS
for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt)
+#else
+ for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, rhsIt.index()); lhsIt; ++lhsIt)
+#endif
{
tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
}
@@ -140,8 +153,58 @@ struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,R
}
};
+#ifndef EIGEN_TEST_EVALUATORS
// NOTE the 2 others cases (col row *) must never occur since they are caught
// by ProductReturnType which transforms it to (col col *) by evaluating rhs.
+#else
+template<typename Lhs, typename Rhs, typename ResultType>
+struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor>
+{
+ typedef typename ResultType::RealScalar RealScalar;
+ static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
+ {
+ typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename Lhs::Index> RowMajorMatrixLhs;
+ RowMajorMatrixLhs rowLhs(lhs);
+ sparse_sparse_product_with_pruning_selector<RowMajorMatrixLhs,Rhs,ResultType,RowMajor,RowMajor>(rowLhs,rhs,res,tolerance);
+ }
+};
+
+template<typename Lhs, typename Rhs, typename ResultType>
+struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor>
+{
+ typedef typename ResultType::RealScalar RealScalar;
+ static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
+ {
+ typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename Lhs::Index> RowMajorMatrixRhs;
+ RowMajorMatrixRhs rowRhs(rhs);
+ sparse_sparse_product_with_pruning_selector<Lhs,RowMajorMatrixRhs,ResultType,RowMajor,RowMajor,RowMajor>(lhs,rowRhs,res,tolerance);
+ }
+};
+
+template<typename Lhs, typename Rhs, typename ResultType>
+struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor>
+{
+ typedef typename ResultType::RealScalar RealScalar;
+ static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
+ {
+ typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixRhs;
+ ColMajorMatrixRhs colRhs(rhs);
+ internal::sparse_sparse_product_with_pruning_impl<Lhs,ColMajorMatrixRhs,ResultType>(lhs, colRhs, res, tolerance);
+ }
+};
+
+template<typename Lhs, typename Rhs, typename ResultType>
+struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor>
+{
+ typedef typename ResultType::RealScalar RealScalar;
+ static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
+ {
+ typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixLhs;
+ ColMajorMatrixLhs colLhs(lhs);
+ internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,Rhs,ResultType>(colLhs, rhs, res, tolerance);
+ }
+};
+#endif
} // end namespace internal
diff --git a/Eigen/src/SparseCore/SparseView.h b/Eigen/src/SparseCore/SparseView.h
index 96d0a849c..7bffbb9cd 100644
--- a/Eigen/src/SparseCore/SparseView.h
+++ b/Eigen/src/SparseCore/SparseView.h
@@ -233,10 +233,30 @@ struct unary_evaluator<SparseView<ArgType>, IndexBased>
#endif // EIGEN_TEST_EVALUATORS
template<typename Derived>
-const SparseView<Derived> MatrixBase<Derived>::sparseView(const Scalar& m_reference,
- const typename NumTraits<Scalar>::Real& m_epsilon) const
+const SparseView<Derived> MatrixBase<Derived>::sparseView(const Scalar& reference,
+ const typename NumTraits<Scalar>::Real& epsilon) const
{
- return SparseView<Derived>(derived(), m_reference, m_epsilon);
+ return SparseView<Derived>(derived(), reference, epsilon);
+}
+
+/** \returns an expression of \c *this with values smaller than
+ * \a reference * \a epsilon are removed.
+ *
+ * This method is typically used in conjunction with the product of two sparse matrices
+ * to automatically prune the smallest values as follows:
+ * \code
+ * C = (A*B).pruned(); // suppress numerical zeros (exact)
+ * C = (A*B).pruned(ref);
+ * C = (A*B).pruned(ref,epsilon);
+ * \endcode
+ * where \c ref is a meaningful non zero reference value.
+ * */
+template<typename Derived>
+const SparseView<Derived>
+SparseMatrixBase<Derived>::pruned(const Scalar& reference,
+ const RealScalar& epsilon) const
+{
+ return SparseView<Derived>(derived(), reference, epsilon);
}
} // end namespace Eigen