diff options
Diffstat (limited to 'Eigen/src/SparseCore/SparseSparseProductWithPruning.h')
-rw-r--r-- | Eigen/src/SparseCore/SparseSparseProductWithPruning.h | 58 |
1 files changed, 53 insertions, 5 deletions
diff --git a/Eigen/src/SparseCore/SparseSparseProductWithPruning.h b/Eigen/src/SparseCore/SparseSparseProductWithPruning.h index fcc18f5c9..f291f8cef 100644 --- a/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +++ b/Eigen/src/SparseCore/SparseSparseProductWithPruning.h @@ -1,7 +1,7 @@ // This file is part of Eigen, a lightweight C++ template library // for linear algebra. // -// Copyright (C) 2008-2011 Gael Guennebaud <gael.guennebaud@inria.fr> +// Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr> // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed @@ -46,6 +46,9 @@ static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& r res.resize(cols, rows); else res.resize(rows, cols); + + typename evaluator<Lhs>::type lhsEval(lhs); + typename evaluator<Rhs>::type rhsEval(rhs); res.reserve(estimated_nnz_prod); double ratioColRes = double(estimated_nnz_prod)/double(lhs.rows()*rhs.cols()); @@ -56,12 +59,12 @@ static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& r // let's do a more accurate determination of the nnz ratio for the current column j of res tempVector.init(ratioColRes); tempVector.setZero(); - for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt) + for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt) { // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index()) tempVector.restart(); Scalar x = rhsIt.value(); - for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt) + for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, rhsIt.index()); lhsIt; ++lhsIt) { tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x; } @@ -140,8 +143,53 @@ struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,R } }; -// NOTE the 2 others cases (col row *) must never occur since they are caught -// by ProductReturnType which transforms it to (col col *) by evaluating rhs. +template<typename Lhs, typename Rhs, typename ResultType> +struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor> +{ + typedef typename ResultType::RealScalar RealScalar; + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) + { + typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename Lhs::Index> RowMajorMatrixLhs; + RowMajorMatrixLhs rowLhs(lhs); + sparse_sparse_product_with_pruning_selector<RowMajorMatrixLhs,Rhs,ResultType,RowMajor,RowMajor>(rowLhs,rhs,res,tolerance); + } +}; + +template<typename Lhs, typename Rhs, typename ResultType> +struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor> +{ + typedef typename ResultType::RealScalar RealScalar; + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) + { + typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename Lhs::Index> RowMajorMatrixRhs; + RowMajorMatrixRhs rowRhs(rhs); + sparse_sparse_product_with_pruning_selector<Lhs,RowMajorMatrixRhs,ResultType,RowMajor,RowMajor,RowMajor>(lhs,rowRhs,res,tolerance); + } +}; + +template<typename Lhs, typename Rhs, typename ResultType> +struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor> +{ + typedef typename ResultType::RealScalar RealScalar; + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) + { + typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixRhs; + ColMajorMatrixRhs colRhs(rhs); + internal::sparse_sparse_product_with_pruning_impl<Lhs,ColMajorMatrixRhs,ResultType>(lhs, colRhs, res, tolerance); + } +}; + +template<typename Lhs, typename Rhs, typename ResultType> +struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor> +{ + typedef typename ResultType::RealScalar RealScalar; + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) + { + typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixLhs; + ColMajorMatrixLhs colLhs(lhs); + internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,Rhs,ResultType>(colLhs, rhs, res, tolerance); + } +}; } // end namespace internal |