aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Felipe Attanasio <oraculoide@gmail.com>2020-05-14 22:38:20 +0000
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-05-14 22:38:20 +0000
commitd640276d31a7dea9207a68a061a6fa7c9fdf50e5 (patch)
treea25c77d7c781b4876313c854a9f677a9597eddf5
parentfa8fd4b4d57323384644394c651ca106d299695f (diff)
Added support for reverse iterators for Vectorwise operations.
-rw-r--r--Eigen/src/Core/StlIterators.h129
-rw-r--r--Eigen/src/Core/VectorwiseOp.h36
-rw-r--r--Eigen/src/Core/util/ForwardDeclarations.h1
-rw-r--r--test/stl_iterators.cpp15
4 files changed, 167 insertions, 14 deletions
diff --git a/Eigen/src/Core/StlIterators.h b/Eigen/src/Core/StlIterators.h
index 0d8bd1aa3..8584e0e03 100644
--- a/Eigen/src/Core/StlIterators.h
+++ b/Eigen/src/Core/StlIterators.h
@@ -93,6 +93,85 @@ protected:
Index m_index;
};
+template<typename Derived>
+class indexed_based_stl_reverse_iterator_base
+{
+protected:
+ typedef indexed_based_stl_iterator_traits<Derived> traits;
+ typedef typename traits::XprType XprType;
+ typedef indexed_based_stl_reverse_iterator_base<typename traits::non_const_iterator> non_const_iterator;
+ typedef indexed_based_stl_reverse_iterator_base<typename traits::const_iterator> const_iterator;
+ typedef typename internal::conditional<internal::is_const<XprType>::value,non_const_iterator,const_iterator>::type other_iterator;
+ // NOTE: in C++03 we cannot declare friend classes through typedefs because we need to write friend class:
+ friend class indexed_based_stl_reverse_iterator_base<typename traits::const_iterator>;
+ friend class indexed_based_stl_reverse_iterator_base<typename traits::non_const_iterator>;
+public:
+ typedef Index difference_type;
+ typedef std::random_access_iterator_tag iterator_category;
+
+ indexed_based_stl_reverse_iterator_base() : mp_xpr(0), m_index(0) {}
+ indexed_based_stl_reverse_iterator_base(XprType& xpr, Index index) : mp_xpr(&xpr), m_index(index) {}
+
+ indexed_based_stl_reverse_iterator_base(const non_const_iterator& other)
+ : mp_xpr(other.mp_xpr), m_index(other.m_index)
+ {}
+
+ indexed_based_stl_reverse_iterator_base& operator=(const non_const_iterator& other)
+ {
+ mp_xpr = other.mp_xpr;
+ m_index = other.m_index;
+ return *this;
+ }
+
+ Derived& operator++() { --m_index; return derived(); }
+ Derived& operator--() { ++m_index; return derived(); }
+
+ Derived operator++(int) { Derived prev(derived()); operator++(); return prev;}
+ Derived operator--(int) { Derived prev(derived()); operator--(); return prev;}
+
+ friend Derived operator+(const indexed_based_stl_reverse_iterator_base& a, Index b) { Derived ret(a.derived()); ret += b; return ret; }
+ friend Derived operator-(const indexed_based_stl_reverse_iterator_base& a, Index b) { Derived ret(a.derived()); ret -= b; return ret; }
+ friend Derived operator+(Index a, const indexed_based_stl_reverse_iterator_base& b) { Derived ret(b.derived()); ret += a; return ret; }
+ friend Derived operator-(Index a, const indexed_based_stl_reverse_iterator_base& b) { Derived ret(b.derived()); ret -= a; return ret; }
+
+ Derived& operator+=(Index b) { m_index -= b; return derived(); }
+ Derived& operator-=(Index b) { m_index += b; return derived(); }
+
+ difference_type operator-(const indexed_based_stl_reverse_iterator_base& other) const
+ {
+ eigen_assert(mp_xpr == other.mp_xpr);
+ return other.m_index - m_index;
+ }
+
+ difference_type operator-(const other_iterator& other) const
+ {
+ eigen_assert(mp_xpr == other.mp_xpr);
+ return other.m_index - m_index;
+ }
+
+ bool operator==(const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index == other.m_index; }
+ bool operator!=(const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index != other.m_index; }
+ bool operator< (const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index > other.m_index; }
+ bool operator<=(const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index >= other.m_index; }
+ bool operator> (const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index < other.m_index; }
+ bool operator>=(const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index <= other.m_index; }
+
+ bool operator==(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index == other.m_index; }
+ bool operator!=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index != other.m_index; }
+ bool operator< (const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index > other.m_index; }
+ bool operator<=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index >= other.m_index; }
+ bool operator> (const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index < other.m_index; }
+ bool operator>=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index <= other.m_index; }
+
+protected:
+
+ Derived& derived() { return static_cast<Derived&>(*this); }
+ const Derived& derived() const { return static_cast<const Derived&>(*this); }
+
+ XprType *mp_xpr;
+ Index m_index;
+};
+
template<typename XprType>
class pointer_based_stl_iterator
{
@@ -267,6 +346,54 @@ public:
pointer operator->() const { return (*mp_xpr).template subVector<Direction>(m_index); }
};
+template<typename _XprType, DirectionType Direction>
+struct indexed_based_stl_iterator_traits<subvector_stl_reverse_iterator<_XprType,Direction> >
+{
+ typedef _XprType XprType;
+ typedef subvector_stl_reverse_iterator<typename internal::remove_const<XprType>::type, Direction> non_const_iterator;
+ typedef subvector_stl_reverse_iterator<typename internal::add_const<XprType>::type, Direction> const_iterator;
+};
+
+template<typename XprType, DirectionType Direction>
+class subvector_stl_reverse_iterator : public indexed_based_stl_reverse_iterator_base<subvector_stl_reverse_iterator<XprType,Direction> >
+{
+protected:
+
+ enum { is_lvalue = internal::is_lvalue<XprType>::value };
+
+ typedef indexed_based_stl_reverse_iterator_base<subvector_stl_reverse_iterator> Base;
+ using Base::m_index;
+ using Base::mp_xpr;
+
+ typedef typename internal::conditional<Direction==Vertical,typename XprType::ColXpr,typename XprType::RowXpr>::type SubVectorType;
+ typedef typename internal::conditional<Direction==Vertical,typename XprType::ConstColXpr,typename XprType::ConstRowXpr>::type ConstSubVectorType;
+
+
+public:
+ typedef typename internal::conditional<bool(is_lvalue), SubVectorType, ConstSubVectorType>::type reference;
+ typedef typename reference::PlainObject value_type;
+
+private:
+ class subvector_stl_reverse_iterator_ptr
+ {
+ public:
+ subvector_stl_reverse_iterator_ptr(const reference &subvector) : m_subvector(subvector) {}
+ reference* operator->() { return &m_subvector; }
+ private:
+ reference m_subvector;
+ };
+public:
+
+ typedef subvector_stl_reverse_iterator_ptr pointer;
+
+ subvector_stl_reverse_iterator() : Base() {}
+ subvector_stl_reverse_iterator(XprType& xpr, Index index) : Base(xpr,index) {}
+
+ reference operator*() const { return (*mp_xpr).template subVector<Direction>(m_index); }
+ reference operator[](Index i) const { return (*mp_xpr).template subVector<Direction>(m_index+i); }
+ pointer operator->() const { return (*mp_xpr).template subVector<Direction>(m_index); }
+};
+
} // namespace internal
@@ -328,4 +455,4 @@ inline typename DenseBase<Derived>::const_iterator DenseBase<Derived>::cend() co
return const_iterator(derived(), size());
}
-} // namespace Eigen
+} // namespace Eigen \ No newline at end of file
diff --git a/Eigen/src/Core/VectorwiseOp.h b/Eigen/src/Core/VectorwiseOp.h
index 865691b32..91a6c0353 100644
--- a/Eigen/src/Core/VectorwiseOp.h
+++ b/Eigen/src/Core/VectorwiseOp.h
@@ -279,27 +279,47 @@ template<typename ExpressionType, int Direction> class VectorwiseOp
/** This is the const version of iterator (aka read-only) */
random_access_iterator_type const_iterator;
#else
- typedef internal::subvector_stl_iterator<ExpressionType, DirectionType(Direction)> iterator;
- typedef internal::subvector_stl_iterator<const ExpressionType, DirectionType(Direction)> const_iterator;
+ typedef internal::subvector_stl_iterator<ExpressionType, DirectionType(Direction)> iterator;
+ typedef internal::subvector_stl_iterator<const ExpressionType, DirectionType(Direction)> const_iterator;
+ typedef internal::subvector_stl_reverse_iterator<ExpressionType, DirectionType(Direction)> reverse_iterator;
+ typedef internal::subvector_stl_reverse_iterator<const ExpressionType, DirectionType(Direction)> const_reverse_iterator;
#endif
/** returns an iterator to the first row (rowwise) or column (colwise) of the nested expression.
* \sa end(), cbegin()
*/
- iterator begin() { return iterator (m_matrix, 0); }
+ iterator begin() { return iterator (m_matrix, 0); }
/** const version of begin() */
- const_iterator begin() const { return const_iterator(m_matrix, 0); }
+ const_iterator begin() const { return const_iterator(m_matrix, 0); }
/** const version of begin() */
- const_iterator cbegin() const { return const_iterator(m_matrix, 0); }
+ const_iterator cbegin() const { return const_iterator(m_matrix, 0); }
+
+ /** returns a reverse iterator to the last row (rowwise) or column (colwise) of the nested expression.
+ * \sa rend(), crbegin()
+ */
+ reverse_iterator rbegin() { return reverse_iterator (m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()-1); }
+ /** const version of rbegin() */
+ const_reverse_iterator rbegin() const { return const_reverse_iterator (m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()-1); }
+ /** const version of rbegin() */
+ const_reverse_iterator crbegin() const { return const_reverse_iterator (m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()-1); }
/** returns an iterator to the row (resp. column) following the last row (resp. column) of the nested expression
* \sa begin(), cend()
*/
- iterator end() { return iterator (m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()); }
+ iterator end() { return iterator (m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()); }
/** const version of end() */
- const_iterator end() const { return const_iterator(m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()); }
+ const_iterator end() const { return const_iterator(m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()); }
/** const version of end() */
- const_iterator cend() const { return const_iterator(m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()); }
+ const_iterator cend() const { return const_iterator(m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()); }
+
+ /** returns a reverse iterator to the row (resp. column) before the first row (resp. column) of the nested expression
+ * \sa begin(), cend()
+ */
+ reverse_iterator rend() { return reverse_iterator (m_matrix, -1); }
+ /** const version of rend() */
+ const_reverse_iterator rend() const { return const_reverse_iterator (m_matrix, -1); }
+ /** const version of rend() */
+ const_reverse_iterator crend() const { return const_reverse_iterator (m_matrix, -1); }
/** \returns a row or column vector expression of \c *this reduxed by \a func
*
diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h
index cd0bdb5a7..031a8ec3c 100644
--- a/Eigen/src/Core/util/ForwardDeclarations.h
+++ b/Eigen/src/Core/util/ForwardDeclarations.h
@@ -134,6 +134,7 @@ namespace internal {
template<typename XprType> class generic_randaccess_stl_iterator;
template<typename XprType> class pointer_based_stl_iterator;
template<typename XprType, DirectionType Direction> class subvector_stl_iterator;
+template<typename XprType, DirectionType Direction> class subvector_stl_reverse_iterator;
template<typename DecompositionType> struct kernel_retval_base;
template<typename DecompositionType> struct kernel_retval;
template<typename DecompositionType> struct image_retval_base;
diff --git a/test/stl_iterators.cpp b/test/stl_iterators.cpp
index 25468eb49..997f8016c 100644
--- a/test/stl_iterators.cpp
+++ b/test/stl_iterators.cpp
@@ -431,22 +431,27 @@ void test_stl_iterators(int rows=Rows, int cols=Cols)
{
RowVectorType row = RowVectorType::Random(cols);
A.rowwise() = row;
- VERIFY( std::all_of(A.rowwise().begin(), A.rowwise().end(), [&row](typename ColMatrixType::RowXpr x) { return internal::isApprox(x.squaredNorm(),row.squaredNorm()); }) );
+ VERIFY( std::all_of(A.rowwise().begin(), A.rowwise().end(), [&row](typename ColMatrixType::RowXpr x) { return internal::isApprox(x.squaredNorm(),row.squaredNorm()); }) );
+ VERIFY( std::all_of(A.rowwise().rbegin(), A.rowwise().rend(), [&row](typename ColMatrixType::RowXpr x) { return internal::isApprox(x.squaredNorm(),row.squaredNorm()); }) );
VectorType col = VectorType::Random(rows);
A.colwise() = col;
- VERIFY( std::all_of(A.colwise().begin(), A.colwise().end(), [&col](typename ColMatrixType::ColXpr x) { return internal::isApprox(x.squaredNorm(),col.squaredNorm()); }) );
- VERIFY( std::all_of(A.colwise().cbegin(), A.colwise().cend(), [&col](typename ColMatrixType::ConstColXpr x) { return internal::isApprox(x.squaredNorm(),col.squaredNorm()); }) );
+ VERIFY( std::all_of(A.colwise().begin(), A.colwise().end(), [&col](typename ColMatrixType::ColXpr x) { return internal::isApprox(x.squaredNorm(),col.squaredNorm()); }) );
+ VERIFY( std::all_of(A.colwise().rbegin(), A.colwise().rend(), [&col](typename ColMatrixType::ColXpr x) { return internal::isApprox(x.squaredNorm(),col.squaredNorm()); }) );
+ VERIFY( std::all_of(A.colwise().cbegin(), A.colwise().cend(), [&col](typename ColMatrixType::ConstColXpr x) { return internal::isApprox(x.squaredNorm(),col.squaredNorm()); }) );
+ VERIFY( std::all_of(A.colwise().crbegin(), A.colwise().crend(), [&col](typename ColMatrixType::ConstColXpr x) { return internal::isApprox(x.squaredNorm(),col.squaredNorm()); }) );
i = internal::random<Index>(0,A.rows()-1);
A.setRandom();
A.row(i).setZero();
- VERIFY_IS_EQUAL( std::find_if(A.rowwise().begin(), A.rowwise().end(), [](typename ColMatrixType::RowXpr x) { return x.squaredNorm() == Scalar(0); })-A.rowwise().begin(), i );
+ VERIFY_IS_EQUAL( std::find_if(A.rowwise().begin(), A.rowwise().end(), [](typename ColMatrixType::RowXpr x) { return x.squaredNorm() == Scalar(0); })-A.rowwise().begin(), i );
+ VERIFY_IS_EQUAL( std::find_if(A.rowwise().rbegin(), A.rowwise().rend(), [](typename ColMatrixType::RowXpr x) { return x.squaredNorm() == Scalar(0); })-A.rowwise().rbegin(), (A.rows()-1) - i );
j = internal::random<Index>(0,A.cols()-1);
A.setRandom();
A.col(j).setZero();
- VERIFY_IS_EQUAL( std::find_if(A.colwise().begin(), A.colwise().end(), [](typename ColMatrixType::ColXpr x) { return x.squaredNorm() == Scalar(0); })-A.colwise().begin(), j );
+ VERIFY_IS_EQUAL( std::find_if(A.colwise().begin(), A.colwise().end(), [](typename ColMatrixType::ColXpr x) { return x.squaredNorm() == Scalar(0); })-A.colwise().begin(), j );
+ VERIFY_IS_EQUAL( std::find_if(A.colwise().rbegin(), A.colwise().rend(), [](typename ColMatrixType::ColXpr x) { return x.squaredNorm() == Scalar(0); })-A.colwise().rbegin(), (A.cols()-1) - j );
}
{