diff options
author | Felipe Attanasio <oraculoide@gmail.com> | 2020-05-14 22:38:20 +0000 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2020-05-14 22:38:20 +0000 |
commit | d640276d31a7dea9207a68a061a6fa7c9fdf50e5 (patch) | |
tree | a25c77d7c781b4876313c854a9f677a9597eddf5 | |
parent | fa8fd4b4d57323384644394c651ca106d299695f (diff) |
Added support for reverse iterators for Vectorwise operations.
-rw-r--r-- | Eigen/src/Core/StlIterators.h | 129 | ||||
-rw-r--r-- | Eigen/src/Core/VectorwiseOp.h | 36 | ||||
-rw-r--r-- | Eigen/src/Core/util/ForwardDeclarations.h | 1 | ||||
-rw-r--r-- | test/stl_iterators.cpp | 15 |
4 files changed, 167 insertions, 14 deletions
diff --git a/Eigen/src/Core/StlIterators.h b/Eigen/src/Core/StlIterators.h index 0d8bd1aa3..8584e0e03 100644 --- a/Eigen/src/Core/StlIterators.h +++ b/Eigen/src/Core/StlIterators.h @@ -93,6 +93,85 @@ protected: Index m_index; }; +template<typename Derived> +class indexed_based_stl_reverse_iterator_base +{ +protected: + typedef indexed_based_stl_iterator_traits<Derived> traits; + typedef typename traits::XprType XprType; + typedef indexed_based_stl_reverse_iterator_base<typename traits::non_const_iterator> non_const_iterator; + typedef indexed_based_stl_reverse_iterator_base<typename traits::const_iterator> const_iterator; + typedef typename internal::conditional<internal::is_const<XprType>::value,non_const_iterator,const_iterator>::type other_iterator; + // NOTE: in C++03 we cannot declare friend classes through typedefs because we need to write friend class: + friend class indexed_based_stl_reverse_iterator_base<typename traits::const_iterator>; + friend class indexed_based_stl_reverse_iterator_base<typename traits::non_const_iterator>; +public: + typedef Index difference_type; + typedef std::random_access_iterator_tag iterator_category; + + indexed_based_stl_reverse_iterator_base() : mp_xpr(0), m_index(0) {} + indexed_based_stl_reverse_iterator_base(XprType& xpr, Index index) : mp_xpr(&xpr), m_index(index) {} + + indexed_based_stl_reverse_iterator_base(const non_const_iterator& other) + : mp_xpr(other.mp_xpr), m_index(other.m_index) + {} + + indexed_based_stl_reverse_iterator_base& operator=(const non_const_iterator& other) + { + mp_xpr = other.mp_xpr; + m_index = other.m_index; + return *this; + } + + Derived& operator++() { --m_index; return derived(); } + Derived& operator--() { ++m_index; return derived(); } + + Derived operator++(int) { Derived prev(derived()); operator++(); return prev;} + Derived operator--(int) { Derived prev(derived()); operator--(); return prev;} + + friend Derived operator+(const indexed_based_stl_reverse_iterator_base& a, Index b) { Derived ret(a.derived()); ret += b; return ret; } + friend Derived operator-(const indexed_based_stl_reverse_iterator_base& a, Index b) { Derived ret(a.derived()); ret -= b; return ret; } + friend Derived operator+(Index a, const indexed_based_stl_reverse_iterator_base& b) { Derived ret(b.derived()); ret += a; return ret; } + friend Derived operator-(Index a, const indexed_based_stl_reverse_iterator_base& b) { Derived ret(b.derived()); ret -= a; return ret; } + + Derived& operator+=(Index b) { m_index -= b; return derived(); } + Derived& operator-=(Index b) { m_index += b; return derived(); } + + difference_type operator-(const indexed_based_stl_reverse_iterator_base& other) const + { + eigen_assert(mp_xpr == other.mp_xpr); + return other.m_index - m_index; + } + + difference_type operator-(const other_iterator& other) const + { + eigen_assert(mp_xpr == other.mp_xpr); + return other.m_index - m_index; + } + + bool operator==(const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index == other.m_index; } + bool operator!=(const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index != other.m_index; } + bool operator< (const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index > other.m_index; } + bool operator<=(const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index >= other.m_index; } + bool operator> (const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index < other.m_index; } + bool operator>=(const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index <= other.m_index; } + + bool operator==(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index == other.m_index; } + bool operator!=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index != other.m_index; } + bool operator< (const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index > other.m_index; } + bool operator<=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index >= other.m_index; } + bool operator> (const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index < other.m_index; } + bool operator>=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index <= other.m_index; } + +protected: + + Derived& derived() { return static_cast<Derived&>(*this); } + const Derived& derived() const { return static_cast<const Derived&>(*this); } + + XprType *mp_xpr; + Index m_index; +}; + template<typename XprType> class pointer_based_stl_iterator { @@ -267,6 +346,54 @@ public: pointer operator->() const { return (*mp_xpr).template subVector<Direction>(m_index); } }; +template<typename _XprType, DirectionType Direction> +struct indexed_based_stl_iterator_traits<subvector_stl_reverse_iterator<_XprType,Direction> > +{ + typedef _XprType XprType; + typedef subvector_stl_reverse_iterator<typename internal::remove_const<XprType>::type, Direction> non_const_iterator; + typedef subvector_stl_reverse_iterator<typename internal::add_const<XprType>::type, Direction> const_iterator; +}; + +template<typename XprType, DirectionType Direction> +class subvector_stl_reverse_iterator : public indexed_based_stl_reverse_iterator_base<subvector_stl_reverse_iterator<XprType,Direction> > +{ +protected: + + enum { is_lvalue = internal::is_lvalue<XprType>::value }; + + typedef indexed_based_stl_reverse_iterator_base<subvector_stl_reverse_iterator> Base; + using Base::m_index; + using Base::mp_xpr; + + typedef typename internal::conditional<Direction==Vertical,typename XprType::ColXpr,typename XprType::RowXpr>::type SubVectorType; + typedef typename internal::conditional<Direction==Vertical,typename XprType::ConstColXpr,typename XprType::ConstRowXpr>::type ConstSubVectorType; + + +public: + typedef typename internal::conditional<bool(is_lvalue), SubVectorType, ConstSubVectorType>::type reference; + typedef typename reference::PlainObject value_type; + +private: + class subvector_stl_reverse_iterator_ptr + { + public: + subvector_stl_reverse_iterator_ptr(const reference &subvector) : m_subvector(subvector) {} + reference* operator->() { return &m_subvector; } + private: + reference m_subvector; + }; +public: + + typedef subvector_stl_reverse_iterator_ptr pointer; + + subvector_stl_reverse_iterator() : Base() {} + subvector_stl_reverse_iterator(XprType& xpr, Index index) : Base(xpr,index) {} + + reference operator*() const { return (*mp_xpr).template subVector<Direction>(m_index); } + reference operator[](Index i) const { return (*mp_xpr).template subVector<Direction>(m_index+i); } + pointer operator->() const { return (*mp_xpr).template subVector<Direction>(m_index); } +}; + } // namespace internal @@ -328,4 +455,4 @@ inline typename DenseBase<Derived>::const_iterator DenseBase<Derived>::cend() co return const_iterator(derived(), size()); } -} // namespace Eigen +} // namespace Eigen
\ No newline at end of file diff --git a/Eigen/src/Core/VectorwiseOp.h b/Eigen/src/Core/VectorwiseOp.h index 865691b32..91a6c0353 100644 --- a/Eigen/src/Core/VectorwiseOp.h +++ b/Eigen/src/Core/VectorwiseOp.h @@ -279,27 +279,47 @@ template<typename ExpressionType, int Direction> class VectorwiseOp /** This is the const version of iterator (aka read-only) */ random_access_iterator_type const_iterator; #else - typedef internal::subvector_stl_iterator<ExpressionType, DirectionType(Direction)> iterator; - typedef internal::subvector_stl_iterator<const ExpressionType, DirectionType(Direction)> const_iterator; + typedef internal::subvector_stl_iterator<ExpressionType, DirectionType(Direction)> iterator; + typedef internal::subvector_stl_iterator<const ExpressionType, DirectionType(Direction)> const_iterator; + typedef internal::subvector_stl_reverse_iterator<ExpressionType, DirectionType(Direction)> reverse_iterator; + typedef internal::subvector_stl_reverse_iterator<const ExpressionType, DirectionType(Direction)> const_reverse_iterator; #endif /** returns an iterator to the first row (rowwise) or column (colwise) of the nested expression. * \sa end(), cbegin() */ - iterator begin() { return iterator (m_matrix, 0); } + iterator begin() { return iterator (m_matrix, 0); } /** const version of begin() */ - const_iterator begin() const { return const_iterator(m_matrix, 0); } + const_iterator begin() const { return const_iterator(m_matrix, 0); } /** const version of begin() */ - const_iterator cbegin() const { return const_iterator(m_matrix, 0); } + const_iterator cbegin() const { return const_iterator(m_matrix, 0); } + + /** returns a reverse iterator to the last row (rowwise) or column (colwise) of the nested expression. + * \sa rend(), crbegin() + */ + reverse_iterator rbegin() { return reverse_iterator (m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()-1); } + /** const version of rbegin() */ + const_reverse_iterator rbegin() const { return const_reverse_iterator (m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()-1); } + /** const version of rbegin() */ + const_reverse_iterator crbegin() const { return const_reverse_iterator (m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()-1); } /** returns an iterator to the row (resp. column) following the last row (resp. column) of the nested expression * \sa begin(), cend() */ - iterator end() { return iterator (m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()); } + iterator end() { return iterator (m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()); } /** const version of end() */ - const_iterator end() const { return const_iterator(m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()); } + const_iterator end() const { return const_iterator(m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()); } /** const version of end() */ - const_iterator cend() const { return const_iterator(m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()); } + const_iterator cend() const { return const_iterator(m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()); } + + /** returns a reverse iterator to the row (resp. column) before the first row (resp. column) of the nested expression + * \sa begin(), cend() + */ + reverse_iterator rend() { return reverse_iterator (m_matrix, -1); } + /** const version of rend() */ + const_reverse_iterator rend() const { return const_reverse_iterator (m_matrix, -1); } + /** const version of rend() */ + const_reverse_iterator crend() const { return const_reverse_iterator (m_matrix, -1); } /** \returns a row or column vector expression of \c *this reduxed by \a func * diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index cd0bdb5a7..031a8ec3c 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -134,6 +134,7 @@ namespace internal { template<typename XprType> class generic_randaccess_stl_iterator; template<typename XprType> class pointer_based_stl_iterator; template<typename XprType, DirectionType Direction> class subvector_stl_iterator; +template<typename XprType, DirectionType Direction> class subvector_stl_reverse_iterator; template<typename DecompositionType> struct kernel_retval_base; template<typename DecompositionType> struct kernel_retval; template<typename DecompositionType> struct image_retval_base; diff --git a/test/stl_iterators.cpp b/test/stl_iterators.cpp index 25468eb49..997f8016c 100644 --- a/test/stl_iterators.cpp +++ b/test/stl_iterators.cpp @@ -431,22 +431,27 @@ void test_stl_iterators(int rows=Rows, int cols=Cols) { RowVectorType row = RowVectorType::Random(cols); A.rowwise() = row; - VERIFY( std::all_of(A.rowwise().begin(), A.rowwise().end(), [&row](typename ColMatrixType::RowXpr x) { return internal::isApprox(x.squaredNorm(),row.squaredNorm()); }) ); + VERIFY( std::all_of(A.rowwise().begin(), A.rowwise().end(), [&row](typename ColMatrixType::RowXpr x) { return internal::isApprox(x.squaredNorm(),row.squaredNorm()); }) ); + VERIFY( std::all_of(A.rowwise().rbegin(), A.rowwise().rend(), [&row](typename ColMatrixType::RowXpr x) { return internal::isApprox(x.squaredNorm(),row.squaredNorm()); }) ); VectorType col = VectorType::Random(rows); A.colwise() = col; - VERIFY( std::all_of(A.colwise().begin(), A.colwise().end(), [&col](typename ColMatrixType::ColXpr x) { return internal::isApprox(x.squaredNorm(),col.squaredNorm()); }) ); - VERIFY( std::all_of(A.colwise().cbegin(), A.colwise().cend(), [&col](typename ColMatrixType::ConstColXpr x) { return internal::isApprox(x.squaredNorm(),col.squaredNorm()); }) ); + VERIFY( std::all_of(A.colwise().begin(), A.colwise().end(), [&col](typename ColMatrixType::ColXpr x) { return internal::isApprox(x.squaredNorm(),col.squaredNorm()); }) ); + VERIFY( std::all_of(A.colwise().rbegin(), A.colwise().rend(), [&col](typename ColMatrixType::ColXpr x) { return internal::isApprox(x.squaredNorm(),col.squaredNorm()); }) ); + VERIFY( std::all_of(A.colwise().cbegin(), A.colwise().cend(), [&col](typename ColMatrixType::ConstColXpr x) { return internal::isApprox(x.squaredNorm(),col.squaredNorm()); }) ); + VERIFY( std::all_of(A.colwise().crbegin(), A.colwise().crend(), [&col](typename ColMatrixType::ConstColXpr x) { return internal::isApprox(x.squaredNorm(),col.squaredNorm()); }) ); i = internal::random<Index>(0,A.rows()-1); A.setRandom(); A.row(i).setZero(); - VERIFY_IS_EQUAL( std::find_if(A.rowwise().begin(), A.rowwise().end(), [](typename ColMatrixType::RowXpr x) { return x.squaredNorm() == Scalar(0); })-A.rowwise().begin(), i ); + VERIFY_IS_EQUAL( std::find_if(A.rowwise().begin(), A.rowwise().end(), [](typename ColMatrixType::RowXpr x) { return x.squaredNorm() == Scalar(0); })-A.rowwise().begin(), i ); + VERIFY_IS_EQUAL( std::find_if(A.rowwise().rbegin(), A.rowwise().rend(), [](typename ColMatrixType::RowXpr x) { return x.squaredNorm() == Scalar(0); })-A.rowwise().rbegin(), (A.rows()-1) - i ); j = internal::random<Index>(0,A.cols()-1); A.setRandom(); A.col(j).setZero(); - VERIFY_IS_EQUAL( std::find_if(A.colwise().begin(), A.colwise().end(), [](typename ColMatrixType::ColXpr x) { return x.squaredNorm() == Scalar(0); })-A.colwise().begin(), j ); + VERIFY_IS_EQUAL( std::find_if(A.colwise().begin(), A.colwise().end(), [](typename ColMatrixType::ColXpr x) { return x.squaredNorm() == Scalar(0); })-A.colwise().begin(), j ); + VERIFY_IS_EQUAL( std::find_if(A.colwise().rbegin(), A.colwise().rend(), [](typename ColMatrixType::ColXpr x) { return x.squaredNorm() == Scalar(0); })-A.colwise().rbegin(), (A.cols()-1) - j ); } { |