diff options
author | 2014-07-01 13:18:56 +0200 | |
---|---|---|
committer | 2014-07-01 13:18:56 +0200 | |
commit | 746d2db6ed5d1bb104757f2170e2018eda524a12 (patch) | |
tree | 5978687be5f9b87db0c4a384d22033887febbb31 /Eigen/src/SparseCore/SparseSparseProductWithPruning.h | |
parent | 441f97b2df8465cb8d5c601e9f1ed324af71491e (diff) |
Implement evaluators for sparse * sparse with auto pruning.
Diffstat (limited to 'Eigen/src/SparseCore/SparseSparseProductWithPruning.h')
-rw-r--r-- | Eigen/src/SparseCore/SparseSparseProductWithPruning.h | 63 |
1 files changed, 63 insertions, 0 deletions
diff --git a/Eigen/src/SparseCore/SparseSparseProductWithPruning.h b/Eigen/src/SparseCore/SparseSparseProductWithPruning.h index fcc18f5c9..c33ec6bfd 100644 --- a/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +++ b/Eigen/src/SparseCore/SparseSparseProductWithPruning.h @@ -46,6 +46,11 @@ static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& r res.resize(cols, rows); else res.resize(rows, cols); + + #ifdef EIGEN_TEST_EVALUATORS + typename evaluator<Lhs>::type lhsEval(lhs); + typename evaluator<Rhs>::type rhsEval(rhs); + #endif res.reserve(estimated_nnz_prod); double ratioColRes = double(estimated_nnz_prod)/double(lhs.rows()*rhs.cols()); @@ -56,12 +61,20 @@ static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& r // let's do a more accurate determination of the nnz ratio for the current column j of res tempVector.init(ratioColRes); tempVector.setZero(); +#ifndef EIGEN_TEST_EVALUATORS for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt) +#else + for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt) +#endif { // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index()) tempVector.restart(); Scalar x = rhsIt.value(); +#ifndef EIGEN_TEST_EVALUATORS for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt) +#else + for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, rhsIt.index()); lhsIt; ++lhsIt) +#endif { tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x; } @@ -140,8 +153,58 @@ struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,R } }; +#ifndef EIGEN_TEST_EVALUATORS // NOTE the 2 others cases (col row *) must never occur since they are caught // by ProductReturnType which transforms it to (col col *) by evaluating rhs. +#else +template<typename Lhs, typename Rhs, typename ResultType> +struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor> +{ + typedef typename ResultType::RealScalar RealScalar; + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) + { + typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename Lhs::Index> RowMajorMatrixLhs; + RowMajorMatrixLhs rowLhs(lhs); + sparse_sparse_product_with_pruning_selector<RowMajorMatrixLhs,Rhs,ResultType,RowMajor,RowMajor>(rowLhs,rhs,res,tolerance); + } +}; + +template<typename Lhs, typename Rhs, typename ResultType> +struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor> +{ + typedef typename ResultType::RealScalar RealScalar; + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) + { + typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename Lhs::Index> RowMajorMatrixRhs; + RowMajorMatrixRhs rowRhs(rhs); + sparse_sparse_product_with_pruning_selector<Lhs,RowMajorMatrixRhs,ResultType,RowMajor,RowMajor,RowMajor>(lhs,rowRhs,res,tolerance); + } +}; + +template<typename Lhs, typename Rhs, typename ResultType> +struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor> +{ + typedef typename ResultType::RealScalar RealScalar; + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) + { + typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixRhs; + ColMajorMatrixRhs colRhs(rhs); + internal::sparse_sparse_product_with_pruning_impl<Lhs,ColMajorMatrixRhs,ResultType>(lhs, colRhs, res, tolerance); + } +}; + +template<typename Lhs, typename Rhs, typename ResultType> +struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor> +{ + typedef typename ResultType::RealScalar RealScalar; + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) + { + typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixLhs; + ColMajorMatrixLhs colLhs(lhs); + internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,Rhs,ResultType>(colLhs, rhs, res, tolerance); + } +}; +#endif } // end namespace internal |