aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2012-10-03 09:06:19 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2012-10-03 09:06:19 +0200
commitfec6df1f7dc4da6de865cc024c45aee5dbb64d88 (patch)
tree1fc2a900f3e3bbf76e485cab82cf10afba0a8f51
parentf30ca7ed7e1c756fcc76389ddbc361486f7d8c42 (diff)
fix dense=sparse*diagonal (there was an issue in the values returned by the .outer() function of the related iterators)
-rw-r--r--Eigen/src/SparseCore/SparseDiagonalProduct.h12
-rw-r--r--test/sparse_product.cpp7
2 files changed, 17 insertions, 2 deletions
diff --git a/Eigen/src/SparseCore/SparseDiagonalProduct.h b/Eigen/src/SparseCore/SparseDiagonalProduct.h
index 095bf6863..ccba02124 100644
--- a/Eigen/src/SparseCore/SparseDiagonalProduct.h
+++ b/Eigen/src/SparseCore/SparseDiagonalProduct.h
@@ -126,11 +126,15 @@ class sparse_diagonal_product_inner_iterator_selector
SparseInnerVectorSet<Rhs,1>,
typename Lhs::DiagonalVectorType>::InnerIterator Base;
typedef typename Lhs::Index Index;
+ Index m_outer;
public:
inline sparse_diagonal_product_inner_iterator_selector(
const SparseDiagonalProductType& expr, Index outer)
- : Base(expr.rhs().innerVector(outer) .cwiseProduct(expr.lhs().diagonal()), 0)
+ : Base(expr.rhs().innerVector(outer) .cwiseProduct(expr.lhs().diagonal()), 0), m_outer(outer)
{}
+
+ inline Index outer() const { return m_outer; }
+ inline Index col() const { return m_outer; }
};
template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
@@ -160,11 +164,15 @@ class sparse_diagonal_product_inner_iterator_selector
SparseInnerVectorSet<Lhs,1>,
Transpose<const typename Rhs::DiagonalVectorType> >::InnerIterator Base;
typedef typename Lhs::Index Index;
+ Index m_outer;
public:
inline sparse_diagonal_product_inner_iterator_selector(
const SparseDiagonalProductType& expr, Index outer)
- : Base(expr.lhs().innerVector(outer) .cwiseProduct(expr.rhs().diagonal().transpose()), 0)
+ : Base(expr.lhs().innerVector(outer) .cwiseProduct(expr.rhs().diagonal().transpose()), 0), m_outer(outer)
{}
+
+ inline Index outer() const { return m_outer; }
+ inline Index row() const { return m_outer; }
};
} // end namespace internal
diff --git a/test/sparse_product.cpp b/test/sparse_product.cpp
index df660fe12..4eae263fa 100644
--- a/test/sparse_product.cpp
+++ b/test/sparse_product.cpp
@@ -123,6 +123,7 @@ template<typename SparseMatrixType> void sparse_product()
{
DenseMatrix refM2 = DenseMatrix::Zero(rows, cols);
DenseMatrix refM3 = DenseMatrix::Zero(rows, cols);
+ DenseMatrix d3 = DenseMatrix::Zero(rows, cols);
DiagonalMatrix<Scalar,Dynamic> d1(DenseVector::Random(cols));
DiagonalMatrix<Scalar,Dynamic> d2(DenseVector::Random(rows));
SparseMatrixType m2(rows, cols);
@@ -133,6 +134,12 @@ template<typename SparseMatrixType> void sparse_product()
VERIFY_IS_APPROX(m3=m2.transpose()*d2, refM3=refM2.transpose()*d2);
VERIFY_IS_APPROX(m3=d2*m2, refM3=d2*refM2);
VERIFY_IS_APPROX(m3=d1*m2.transpose(), refM3=d1*refM2.transpose());
+
+ // evaluate to a dense matrix to check the .row() and .col() iterator functions
+ VERIFY_IS_APPROX(d3=m2*d1, refM3=refM2*d1);
+ VERIFY_IS_APPROX(d3=m2.transpose()*d2, refM3=refM2.transpose()*d2);
+ VERIFY_IS_APPROX(d3=d2*m2, refM3=d2*refM2);
+ VERIFY_IS_APPROX(d3=d1*m2.transpose(), refM3=d1*refM2.transpose());
}
// test self adjoint products