diff options
Diffstat (limited to 'Eigen/src/SparseCore')
-rw-r--r-- | Eigen/src/SparseCore/SparseSelfAdjointView.h | 2 | ||||
-rw-r--r-- | Eigen/src/SparseCore/SparseTriangularView.h | 62 |
2 files changed, 56 insertions, 8 deletions
diff --git a/Eigen/src/SparseCore/SparseSelfAdjointView.h b/Eigen/src/SparseCore/SparseSelfAdjointView.h index fc6f56adc..09e960ae9 100644 --- a/Eigen/src/SparseCore/SparseSelfAdjointView.h +++ b/Eigen/src/SparseCore/SparseSelfAdjointView.h @@ -113,7 +113,7 @@ template<typename MatrixType, unsigned int UpLo> class SparseSelfAdjointView SparseSelfAdjointView& rankUpdate(const SparseMatrixBase<DerivedU>& u, Scalar alpha = Scalar(1)); /** \internal triggered by sparse_matrix = SparseSelfadjointView; */ - template<typename DestScalar> void evalTo(SparseMatrix<DestScalar,ColMajor,Index>& _dest) const + template<typename DestScalar,int StorageOrder> void evalTo(SparseMatrix<DestScalar,StorageOrder,Index>& _dest) const { internal::permute_symm_to_fullsymm<UpLo>(m_matrix, _dest); } diff --git a/Eigen/src/SparseCore/SparseTriangularView.h b/Eigen/src/SparseCore/SparseTriangularView.h index 0b2d06528..3c0c10242 100644 --- a/Eigen/src/SparseCore/SparseTriangularView.h +++ b/Eigen/src/SparseCore/SparseTriangularView.h @@ -37,9 +37,10 @@ struct traits<SparseTriangularView<MatrixType,Mode> > template<typename MatrixType, int Mode> class SparseTriangularView : public SparseMatrixBase<SparseTriangularView<MatrixType,Mode> > { - enum { SkipFirst = (Mode==Lower && !(MatrixType::Flags&RowMajorBit)) - || (Mode==Upper && (MatrixType::Flags&RowMajorBit)), - SkipLast = !SkipFirst + enum { SkipFirst = ((Mode&Lower) && !(MatrixType::Flags&RowMajorBit)) + || ((Mode&Upper) && (MatrixType::Flags&RowMajorBit)), + SkipLast = !SkipFirst, + HasUnitDiag = (Mode&UnitDiag) ? 1 : 0 }; public: @@ -81,19 +82,61 @@ class SparseTriangularView<MatrixType,Mode>::InnerIterator : public MatrixType:: public: EIGEN_STRONG_INLINE InnerIterator(const SparseTriangularView& view, Index outer) - : Base(view.nestedExpression(), outer) + : Base(view.nestedExpression(), outer), m_returnOne(false) { if(SkipFirst) - while((*this) && this->index()<outer) - ++(*this); + { + while((*this) && (HasUnitDiag ? this->index()<=outer : this->index()<outer)) + Base::operator++(); + if(HasUnitDiag) + m_returnOne = true; + } + else if(HasUnitDiag && ((!Base::operator bool()) || Base::index()>=Base::outer())) + { + if((!SkipFirst) && Base::operator bool()) + Base::operator++(); + m_returnOne = true; + } + } + + EIGEN_STRONG_INLINE InnerIterator& operator++() + { + if(HasUnitDiag && m_returnOne) + m_returnOne = false; + else + { + Base::operator++(); + if(HasUnitDiag && (!SkipFirst) && ((!Base::operator bool()) || Base::index()>=Base::outer())) + { + if((!SkipFirst) && Base::operator bool()) + Base::operator++(); + m_returnOne = true; + } + } + return *this; } + inline Index row() const { return Base::row(); } inline Index col() const { return Base::col(); } + inline Index index() const + { + if(HasUnitDiag && m_returnOne) return Base::outer(); + else return Base::index(); + } + inline Scalar value() const + { + if(HasUnitDiag && m_returnOne) return Scalar(1); + else return Base::value(); + } EIGEN_STRONG_INLINE operator bool() const { - return SkipFirst ? Base::operator bool() : (Base::operator bool() && this->index() <= this->outer()); + if(HasUnitDiag && m_returnOne) + return true; + return (SkipFirst ? Base::operator bool() : (Base::operator bool() && this->index() <= this->outer())); } + protected: + bool m_returnOne; }; template<typename MatrixType, int Mode> @@ -105,10 +148,15 @@ class SparseTriangularView<MatrixType,Mode>::ReverseInnerIterator : public Matri EIGEN_STRONG_INLINE ReverseInnerIterator(const SparseTriangularView& view, Index outer) : Base(view.nestedExpression(), outer) { + eigen_assert((!HasUnitDiag) && "ReverseInnerIterator does not support yet triangular views with a unit diagonal"); if(SkipLast) while((*this) && this->index()>outer) --(*this); } + + EIGEN_STRONG_INLINE InnerIterator& operator--() + { Base::operator--(); return *this; } + inline Index row() const { return Base::row(); } inline Index col() const { return Base::col(); } |