From 3573a107120c3762c1f9b6b17f62d817ee1a9ec0 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Mon, 17 Feb 2014 13:46:17 +0100 Subject: Fix support for row (resp. column) of a column-major (resp. row-major) sparse matrix --- Eigen/src/SparseCore/SparseBlock.h | 133 +++++++++++++++++++++++++++++-------- 1 file changed, 107 insertions(+), 26 deletions(-) (limited to 'Eigen/src/SparseCore/SparseBlock.h') diff --git a/Eigen/src/SparseCore/SparseBlock.h b/Eigen/src/SparseCore/SparseBlock.h index d0436aa7c..4fa98c01c 100644 --- a/Eigen/src/SparseCore/SparseBlock.h +++ b/Eigen/src/SparseCore/SparseBlock.h @@ -335,6 +335,14 @@ const Block SparseMatrixBase::inner } +namespace internal { + +template< typename XprType, int BlockRows, int BlockCols, bool InnerPanel, + bool OuterVector = (BlockCols==1 && XprType::IsRowMajor) || (BlockRows==1 && !XprType::IsRowMajor)> +class GenericSparseBlockInnerIteratorImpl; + +} + /** Generic implementation of sparse Block expression. * Real-only. */ @@ -354,8 +362,8 @@ public: : m_matrix(xpr), m_startRow( (BlockRows==1) && (BlockCols==XprType::ColsAtCompileTime) ? i : 0), m_startCol( (BlockRows==XprType::RowsAtCompileTime) && (BlockCols==1) ? i : 0), - m_blockRows(xpr.rows()), - m_blockCols(xpr.cols()) + m_blockRows(BlockRows==1 ? 1 : xpr.rows()), + m_blockCols(BlockCols==1 ? 1 : xpr.cols()) {} /** Dynamic-size constructor @@ -394,29 +402,8 @@ public: inline const _MatrixTypeNested& nestedExpression() const { return m_matrix; } - class InnerIterator : public _MatrixTypeNested::InnerIterator - { - typedef typename _MatrixTypeNested::InnerIterator Base; - const BlockType& m_block; - Index m_end; - public: - - EIGEN_STRONG_INLINE InnerIterator(const BlockType& block, Index outer) - : Base(block.derived().nestedExpression(), outer + (IsRowMajor ? block.m_startRow.value() : block.m_startCol.value())), - m_block(block), - m_end(IsRowMajor ? block.m_startCol.value()+block.m_blockCols.value() : block.m_startRow.value()+block.m_blockRows.value()) - { - while( (Base::operator bool()) && (Base::index() < (IsRowMajor ? m_block.m_startCol.value() : m_block.m_startRow.value())) ) - Base::operator++(); - } - - inline Index index() const { return Base::index() - (IsRowMajor ? m_block.m_startCol.value() : m_block.m_startRow.value()); } - inline Index outer() const { return Base::outer() - (IsRowMajor ? m_block.m_startRow.value() : m_block.m_startCol.value()); } - inline Index row() const { return Base::row() - m_block.m_startRow.value(); } - inline Index col() const { return Base::col() - m_block.m_startCol.value(); } - - inline operator bool() const { return Base::operator bool() && Base::index() < m_end; } - }; + typedef internal::GenericSparseBlockInnerIteratorImpl InnerIterator; + class ReverseInnerIterator : public _MatrixTypeNested::ReverseInnerIterator { typedef typename _MatrixTypeNested::ReverseInnerIterator Base; @@ -441,7 +428,7 @@ public: inline operator bool() const { return Base::operator bool() && Base::index() >= m_begin; } }; protected: - friend class InnerIterator; + friend class internal::GenericSparseBlockInnerIteratorImpl; friend class ReverseInnerIterator; EIGEN_INHERIT_ASSIGNMENT_OPERATORS(BlockImpl) @@ -454,6 +441,100 @@ public: }; +namespace internal { + template + class GenericSparseBlockInnerIteratorImpl : public Block::_MatrixTypeNested::InnerIterator + { + typedef Block BlockType; + enum { + IsRowMajor = BlockType::IsRowMajor + }; + typedef typename BlockType::_MatrixTypeNested _MatrixTypeNested; + typedef typename BlockType::Index Index; + typedef typename _MatrixTypeNested::InnerIterator Base; + const BlockType& m_block; + Index m_end; + public: + + EIGEN_STRONG_INLINE GenericSparseBlockInnerIteratorImpl(const BlockType& block, Index outer) + : Base(block.derived().nestedExpression(), outer + (IsRowMajor ? block.m_startRow.value() : block.m_startCol.value())), + m_block(block), + m_end(IsRowMajor ? block.m_startCol.value()+block.m_blockCols.value() : block.m_startRow.value()+block.m_blockRows.value()) + { + while( (Base::operator bool()) && (Base::index() < (IsRowMajor ? m_block.m_startCol.value() : m_block.m_startRow.value())) ) + Base::operator++(); + } + + inline Index index() const { return Base::index() - (IsRowMajor ? m_block.m_startCol.value() : m_block.m_startRow.value()); } + inline Index outer() const { return Base::outer() - (IsRowMajor ? m_block.m_startRow.value() : m_block.m_startCol.value()); } + inline Index row() const { return Base::row() - m_block.m_startRow.value(); } + inline Index col() const { return Base::col() - m_block.m_startCol.value(); } + + inline operator bool() const { return Base::operator bool() && Base::index() < m_end; } + }; + + // Row vector of a column-major sparse matrix or column of a row-major one. + template + class GenericSparseBlockInnerIteratorImpl + { + typedef Block BlockType; + enum { + IsRowMajor = BlockType::IsRowMajor + }; + typedef typename BlockType::_MatrixTypeNested _MatrixTypeNested; + typedef typename BlockType::Index Index; + typedef typename BlockType::Scalar Scalar; + const BlockType& m_block; + Index m_outerPos; + Index m_innerIndex; + Scalar m_value; + Index m_end; + public: + + EIGEN_STRONG_INLINE GenericSparseBlockInnerIteratorImpl(const BlockType& block, Index outer = 0) + : + m_block(block), + m_outerPos( (IsRowMajor ? block.m_startCol.value() : block.m_startRow.value()) - 1), // -1 so that operator++ finds the first non-zero entry + m_innerIndex(IsRowMajor ? block.m_startRow.value() : block.m_startCol.value()), + m_end(IsRowMajor ? block.m_startCol.value()+block.m_blockCols.value() : block.m_startRow.value()+block.m_blockRows.value()) + { + EIGEN_UNUSED_VARIABLE(outer); + eigen_assert(outer==0); + + ++(*this); + } + + inline Index index() const { return m_outerPos - (IsRowMajor ? m_block.m_startCol.value() : m_block.m_startRow.value()); } + inline Index outer() const { return 0; } + inline Index row() const { return IsRowMajor ? 0 : index(); } + inline Index col() const { return IsRowMajor ? index() : 0; } + + inline Scalar value() const { return m_value; } + + inline GenericSparseBlockInnerIteratorImpl& operator++() + { + // search next non-zero entry + while(m_outerPos