about summary refs log tree commit diff homepage
path: root/Eigen/src/SparseCore/SparseSparseProductWithPruning.h
diff options
context:
space:
mode:
Diffstat (limited to 'Eigen/src/SparseCore/SparseSparseProductWithPruning.h')
-rw-r--r-- Eigen/src/SparseCore/SparseSparseProductWithPruning.h 15
1 files changed, 8 insertions, 7 deletions
diff --git a/Eigen/src/SparseCore/SparseSparseProductWithPruning.h b/Eigen/src/SparseCore/SparseSparseProductWithPruning.h
index 773d8110c..9bfdb20c5 100644
--- a/Eigen/src/SparseCore/SparseSparseProductWithPruning.h
+++ b/Eigen/src/SparseCore/SparseSparseProductWithPruning.h
@@ -47,9 +47,12 @@ static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& r
AmbiVector<Scalar,Index> tempVector(rows);
// estimate the number of non zero entries
- float ratioLhs = float(lhs.nonZeros())/(float(lhs.rows())*float(lhs.cols()));
- float avgNnzPerRhsColumn = float(rhs.nonZeros())/float(cols);
- float ratioRes = (std::min)(ratioLhs * avgNnzPerRhsColumn, 1.f);
+ // given a rhs column containing Y non zeros, we assume that the respective Y columns
+ // of the lhs differ on average by one non zero, thus the number of non zeros for
+ // the product of a rhs column with the lhs is X+Y, where X is the average number of non zeros
+ // per column of the lhs.
+ // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs)
+ Index estimated_nnz_prod = lhs.nonZeros() + rhs.nonZeros();
// mimics a resizeByInnerOuter:
if(ResultType::IsRowMajor)
@@ -57,13 +60,11 @@ static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& r
else
res.resize(rows, cols);
- res.reserve(Index(ratioRes*rows*cols));
+ res.reserve(estimated_nnz_prod);
for (Index j=0; j<cols; ++j)
{
// let's do a more accurate determination of the nnz ratio for the current column j of res
- //float ratioColRes = (std::min)(ratioLhs * rhs.innerNonZeros(j), 1.f);
- // FIXME find a nice way to get the number of nonzeros of a sub matrix (here an inner vector)
- float ratioColRes = ratioRes;
+ double ratioColRes = (double(rhs.col(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows());
tempVector.init(ratioColRes);
tempVector.setZero();
for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)