diff options
author | Gael Guennebaud <g.gael@free.fr> | 2018-10-02 13:29:32 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2018-10-02 13:29:32 +0200 |
commit | 37e29fc89389ff1514315b1cf96a8253e0b5c69d (patch) | |
tree | 50dcb91253dbb8fd309f69ebd3d0fb7d55d61148 | |
parent | b0c66adfb1c72d060ec98ebf1004a73b6e4cd559 (diff) |
Use Index instead of ptrdiff_t or int, fix random-accessors.
-rw-r--r-- | Eigen/src/Core/StlIterators.h | 35 | ||||
-rw-r--r-- | test/stl_iterators.cpp | 30 |
2 files changed, 44 insertions, 21 deletions
diff --git a/Eigen/src/Core/StlIterators.h b/Eigen/src/Core/StlIterators.h index c2b162a7b..b4c618db2 100644 --- a/Eigen/src/Core/StlIterators.h +++ b/Eigen/src/Core/StlIterators.h @@ -13,7 +13,7 @@ template<typename XprType,typename Derived> class DenseStlIteratorBase { public: - typedef std::ptrdiff_t difference_type; + typedef Index difference_type; typedef std::random_access_iterator_tag iterator_category; DenseStlIteratorBase() : mp_xpr(0), m_index(0) {} @@ -30,13 +30,13 @@ public: Derived operator++(int) { Derived prev(derived()); operator++(); return prev;} Derived operator--(int) { Derived prev(derived()); operator--(); return prev;} - friend Derived operator+(const DenseStlIteratorBase& a, int b) { Derived ret(a.derived()); ret += b; return ret; } - friend Derived operator-(const DenseStlIteratorBase& a, int b) { Derived ret(a.derived()); ret -= b; return ret; } - friend Derived operator+(int a, const DenseStlIteratorBase& b) { Derived ret(b.derived()); ret += a; return ret; } - friend Derived operator-(int a, const DenseStlIteratorBase& b) { Derived ret(b.derived()); ret -= a; return ret; } + friend Derived operator+(const DenseStlIteratorBase& a, Index b) { Derived ret(a.derived()); ret += b; return ret; } + friend Derived operator-(const DenseStlIteratorBase& a, Index b) { Derived ret(a.derived()); ret -= b; return ret; } + friend Derived operator+(Index a, const DenseStlIteratorBase& b) { Derived ret(b.derived()); ret += a; return ret; } + friend Derived operator-(Index a, const DenseStlIteratorBase& b) { Derived ret(b.derived()); ret -= a; return ret; } - Derived& operator+=(int b) { m_index += b; return derived(); } - Derived& operator-=(int b) { m_index -= b; return derived(); } + Derived& operator+=(Index b) { m_index += b; return derived(); } + Derived& operator-=(Index b) { m_index -= b; return derived(); } difference_type operator-(const DenseStlIteratorBase& other) const { eigen_assert(mp_xpr == other.mp_xpr);return m_index - other.m_index; } @@ -84,10 +84,9 @@ public: DenseStlIterator() : Base() {} DenseStlIterator(XprType& xpr, Index index) : Base(xpr,index) {} - reference operator*() const { return (*mp_xpr)(m_index); } - reference operator[](int i) const { return (*mp_xpr)(i); } - - pointer operator->() const { return &((*mp_xpr)(m_index)); } + reference operator*() const { return (*mp_xpr)(m_index); } + reference operator[](Index i) const { return (*mp_xpr)(m_index+i); } + pointer operator->() const { return &((*mp_xpr)(m_index)); } }; template<typename XprType,typename Derived> @@ -154,10 +153,9 @@ public: DenseColStlIterator() : Base() {} DenseColStlIterator(XprType& xpr, Index index) : Base(xpr,index) {} - reference operator*() const { return (*mp_xpr).col(m_index); } - reference operator[](int i) const { return (*mp_xpr).col(i); } - - pointer operator->() const { return &((*mp_xpr).col(m_index)); } + reference operator*() const { return (*mp_xpr).col(m_index); } + reference operator[](Index i) const { return (*mp_xpr).col(m_index+i); } + pointer operator->() const { return &((*mp_xpr).col(m_index)); } }; template<typename XprType> @@ -179,10 +177,9 @@ public: DenseRowStlIterator() : Base() {} DenseRowStlIterator(XprType& xpr, Index index) : Base(xpr,index) {} - reference operator*() const { return (*mp_xpr).row(m_index); } - reference operator[](int i) const { return (*mp_xpr).row(i); } - - pointer operator->() const { return &((*mp_xpr).row(m_index)); } + reference operator*() const { return (*mp_xpr).row(m_index); } + reference operator[](Index i) const { return (*mp_xpr).row(m_index+i); } + pointer operator->() const { return &((*mp_xpr).row(m_index)); } }; diff --git a/test/stl_iterators.cpp b/test/stl_iterators.cpp index 1ed52b354..f56209f07 100644 --- a/test/stl_iterators.cpp +++ b/test/stl_iterators.cpp @@ -59,6 +59,16 @@ void test_range_for_loop(int rows=Rows, int cols=Cols) VERIFY_IS_EQUAL(v,w); #endif + if(rows>=3) { + VERIFY_IS_EQUAL((v.begin()+rows/2)[1], v(rows/2+1)); + + VERIFY_IS_EQUAL((A.allRows().begin()+rows/2)[1], A.row(rows/2+1)); + } + + if(cols>=3) { + VERIFY_IS_EQUAL((A.allCols().begin()+cols/2)[1], A.col(cols/2+1)); + } + if(rows>=2) { v(1) = v(0)-Scalar(1); @@ -84,11 +94,27 @@ void test_range_for_loop(int rows=Rows, int cols=Cols) j = internal::random<Index>(0,A.cols()-1); typename ColMatrixType::ColXpr Acol = A.col(j); std::partial_sum(begin(Acol), end(Acol), begin(v)); - VERIFY_IS_APPROX(v(seq(1,last)), v(seq(0,last-1))+Acol(seq(1,last))); + VERIFY_IS_EQUAL(v(seq(1,last)), v(seq(0,last-1))+Acol(seq(1,last))); // inplace std::partial_sum(begin(Acol), end(Acol), begin(Acol)); - VERIFY_IS_APPROX(v, Acol); + VERIFY_IS_EQUAL(v, Acol); + } + + if(rows>=3) + { + // stress random access + v.setRandom(); + VectorType v1 = v; + std::sort(begin(v1),end(v1)); + std::nth_element(v.begin(), v.begin()+rows/2, v.end()); + VERIFY_IS_APPROX(v1(rows/2), v(rows/2)); + + v.setRandom(); + v1 = v; + std::sort(begin(v1)+rows/2,end(v1)); + std::nth_element(v.begin()+rows/2, v.begin()+rows/4, v.end()); + VERIFY_IS_APPROX(v1(rows/4), v(rows/4)); } #if EIGEN_HAS_CXX11 |