diff options
author | Gael Guennebaud <g.gael@free.fr> | 2018-10-01 23:21:37 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2018-10-01 23:21:37 +0200 |
commit | b0c66adfb1c72d060ec98ebf1004a73b6e4cd559 (patch) | |
tree | 1fa3c3cbe2247a5d87255c1c36a41588d80a3dfd | |
parent | 2088c0897f6ea7175d06de98fe04c71cd453a34d (diff) |
bug #231: initial implementation of STL iterators for dense expressions
-rw-r--r-- | Eigen/Core | 1 | ||||
-rw-r--r-- | Eigen/src/Core/DenseBase.h | 11 | ||||
-rw-r--r-- | Eigen/src/Core/StlIterators.h | 235 | ||||
-rw-r--r-- | Eigen/src/Core/util/ForwardDeclarations.h | 3 | ||||
-rw-r--r-- | test/CMakeLists.txt | 1 | ||||
-rw-r--r-- | test/stl_iterators.cpp | 128 |
6 files changed, 379 insertions, 0 deletions
diff --git a/Eigen/Core b/Eigen/Core index 7347a2480..6fd32dd82 100644 --- a/Eigen/Core +++ b/Eigen/Core @@ -310,6 +310,7 @@ using std::ptrdiff_t; #include "src/Core/Replicate.h" #include "src/Core/Reverse.h" #include "src/Core/ArrayWrapper.h" +#include "src/Core/StlIterators.h" #ifdef EIGEN_USE_BLAS #include "src/Core/products/GeneralMatrixMatrix_BLAS.h" diff --git a/Eigen/src/Core/DenseBase.h b/Eigen/src/Core/DenseBase.h index 0c0ea95f4..93410670f 100644 --- a/Eigen/src/Core/DenseBase.h +++ b/Eigen/src/Core/DenseBase.h @@ -572,6 +572,17 @@ template<typename Derived> class DenseBase } EIGEN_DEVICE_FUNC void reverseInPlace(); + inline DenseStlIterator<Derived> begin(); + inline DenseStlIterator<const Derived> begin() const; + inline DenseStlIterator<const Derived> cbegin() const; + inline DenseStlIterator<Derived> end(); + inline DenseStlIterator<const Derived> end() const; + inline DenseStlIterator<const Derived> cend() const; + inline ColsProxy<Derived> allCols(); + inline ColsProxy<const Derived> allCols() const; + inline RowsProxy<Derived> allRows(); + inline RowsProxy<const Derived> allRows() const; + #define EIGEN_CURRENT_STORAGE_BASE_CLASS Eigen::DenseBase #define EIGEN_DOC_BLOCK_ADDONS_NOT_INNER_PANEL #define EIGEN_DOC_BLOCK_ADDONS_INNER_PANEL_IF(COND) diff --git a/Eigen/src/Core/StlIterators.h b/Eigen/src/Core/StlIterators.h new file mode 100644 index 000000000..c2b162a7b --- /dev/null +++ b/Eigen/src/Core/StlIterators.h @@ -0,0 +1,235 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2018 Gael Guennebaud <gael.guennebaud@inria.fr> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +namespace Eigen { + +template<typename XprType,typename Derived> +class DenseStlIteratorBase +{ +public: + typedef std::ptrdiff_t difference_type; + typedef std::random_access_iterator_tag iterator_category; + + DenseStlIteratorBase() : mp_xpr(0), m_index(0) {} + DenseStlIteratorBase(XprType& xpr, Index index) : mp_xpr(&xpr), m_index(index) {} + + void swap(DenseStlIteratorBase& other) { + std::swap(mp_xpr,other.mp_xpr); + std::swap(m_index,other.m_index); + } + + Derived& operator++() { ++m_index; return derived(); } + Derived& operator--() { --m_index; return derived(); } + + Derived operator++(int) { Derived prev(derived()); operator++(); return prev;} + Derived operator--(int) { Derived prev(derived()); operator--(); return prev;} + + friend Derived operator+(const DenseStlIteratorBase& a, int b) { Derived ret(a.derived()); ret += b; return ret; } + friend Derived operator-(const DenseStlIteratorBase& a, int b) { Derived ret(a.derived()); ret -= b; return ret; } + friend Derived operator+(int a, const DenseStlIteratorBase& b) { Derived ret(b.derived()); ret += a; return ret; } + friend Derived operator-(int a, const DenseStlIteratorBase& b) { Derived ret(b.derived()); ret -= a; return ret; } + + Derived& operator+=(int b) { m_index += b; return derived(); } + Derived& operator-=(int b) { m_index -= b; return derived(); } + + difference_type operator-(const DenseStlIteratorBase& other) const { eigen_assert(mp_xpr == other.mp_xpr);return m_index - other.m_index; } + + bool operator==(const DenseStlIteratorBase& other) { eigen_assert(mp_xpr == other.mp_xpr); return m_index == other.m_index; } + bool operator!=(const DenseStlIteratorBase& other) { eigen_assert(mp_xpr == other.mp_xpr); return m_index != other.m_index; } + bool operator< (const DenseStlIteratorBase& other) { eigen_assert(mp_xpr == other.mp_xpr); return m_index < other.m_index; } + bool operator<=(const DenseStlIteratorBase& other) { eigen_assert(mp_xpr == other.mp_xpr); return m_index <= other.m_index; } + bool operator> (const DenseStlIteratorBase& other) { eigen_assert(mp_xpr == other.mp_xpr); return m_index > other.m_index; } + bool operator>=(const DenseStlIteratorBase& other) { eigen_assert(mp_xpr == other.mp_xpr); return m_index >= other.m_index; } + +protected: + + Derived& derived() { return static_cast<Derived&>(*this); } + const Derived& derived() const { return static_cast<const Derived&>(*this); } + + XprType *mp_xpr; + Index m_index; +}; + +template<typename XprType> +class DenseStlIterator : public DenseStlIteratorBase<XprType, DenseStlIterator<XprType> > +{ +public: + typedef typename XprType::Scalar value_type; + +protected: + + enum { + has_direct_access = (internal::traits<XprType>::Flags & DirectAccessBit) ? 1 : 0, + has_write_access = internal::is_lvalue<XprType>::value + }; + + typedef DenseStlIteratorBase<XprType,DenseStlIterator> Base; + using Base::m_index; + using Base::mp_xpr; + + typedef typename internal::conditional<bool(has_direct_access), const value_type&, const value_type>::type read_only_ref_t; + +public: + + typedef typename internal::conditional<bool(has_write_access), value_type *, const value_type *>::type pointer; + typedef typename internal::conditional<bool(has_write_access), value_type&, read_only_ref_t>::type reference; + + + DenseStlIterator() : Base() {} + DenseStlIterator(XprType& xpr, Index index) : Base(xpr,index) {} + + reference operator*() const { return (*mp_xpr)(m_index); } + reference operator[](int i) const { return (*mp_xpr)(i); } + + pointer operator->() const { return &((*mp_xpr)(m_index)); } +}; + +template<typename XprType,typename Derived> +void swap(DenseStlIteratorBase<XprType,Derived>& a, DenseStlIteratorBase<XprType,Derived>& b) { + a.swap(b); +} + +template<typename Derived> +inline DenseStlIterator<Derived> DenseBase<Derived>::begin() +{ + EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived); + return DenseStlIterator<Derived>(derived(), 0); +} + +template<typename Derived> +inline DenseStlIterator<const Derived> DenseBase<Derived>::begin() const +{ + return cbegin(); +} + +template<typename Derived> +inline DenseStlIterator<const Derived> DenseBase<Derived>::cbegin() const +{ + EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived); + return DenseStlIterator<const Derived>(derived(), 0); +} + +template<typename Derived> +inline DenseStlIterator<Derived> DenseBase<Derived>::end() +{ + EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived); + return DenseStlIterator<Derived>(derived(), size()); +} + +template<typename Derived> +inline DenseStlIterator<const Derived> DenseBase<Derived>::end() const +{ + return cend(); +} + +template<typename Derived> +inline DenseStlIterator<const Derived> DenseBase<Derived>::cend() const +{ + EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived); + return DenseStlIterator<const Derived>(derived(), size()); +} + +template<typename XprType> +class DenseColStlIterator : public DenseStlIteratorBase<XprType, DenseColStlIterator<XprType> > +{ +protected: + + enum { is_lvalue = internal::is_lvalue<XprType>::value }; + + typedef DenseStlIteratorBase<XprType,DenseColStlIterator> Base; + using Base::m_index; + using Base::mp_xpr; + +public: + typedef typename internal::conditional<bool(is_lvalue), typename XprType::ColXpr, typename XprType::ConstColXpr>::type value_type; + typedef value_type* pointer; + typedef value_type reference; + + DenseColStlIterator() : Base() {} + DenseColStlIterator(XprType& xpr, Index index) : Base(xpr,index) {} + + reference operator*() const { return (*mp_xpr).col(m_index); } + reference operator[](int i) const { return (*mp_xpr).col(i); } + + pointer operator->() const { return &((*mp_xpr).col(m_index)); } +}; + +template<typename XprType> +class DenseRowStlIterator : public DenseStlIteratorBase<XprType, DenseRowStlIterator<XprType> > +{ +protected: + + enum { is_lvalue = internal::is_lvalue<XprType>::value }; + + typedef DenseStlIteratorBase<XprType,DenseRowStlIterator> Base; + using Base::m_index; + using Base::mp_xpr; + +public: + typedef typename internal::conditional<bool(is_lvalue), typename XprType::RowXpr, typename XprType::ConstRowXpr>::type value_type; + typedef value_type* pointer; + typedef value_type reference; + + DenseRowStlIterator() : Base() {} + DenseRowStlIterator(XprType& xpr, Index index) : Base(xpr,index) {} + + reference operator*() const { return (*mp_xpr).row(m_index); } + reference operator[](int i) const { return (*mp_xpr).row(i); } + + pointer operator->() const { return &((*mp_xpr).row(m_index)); } +}; + + +template<typename Xpr> +class ColsProxy +{ +public: + ColsProxy(Xpr& xpr) : m_xpr(xpr) {} + DenseColStlIterator<Xpr> begin() const { return DenseColStlIterator<Xpr>(m_xpr, 0); } + DenseColStlIterator<const Xpr> cbegin() const { return DenseColStlIterator<const Xpr>(m_xpr, 0); } + + DenseColStlIterator<Xpr> end() const { return DenseColStlIterator<Xpr>(m_xpr, m_xpr.cols()); } + DenseColStlIterator<const Xpr> cend() const { return DenseColStlIterator<const Xpr>(m_xpr, m_xpr.cols()); } + +protected: + Xpr& m_xpr; +}; + +template<typename Xpr> +class RowsProxy +{ +public: + RowsProxy(Xpr& xpr) : m_xpr(xpr) {} + DenseRowStlIterator<Xpr> begin() const { return DenseRowStlIterator<Xpr>(m_xpr, 0); } + DenseRowStlIterator<const Xpr> cbegin() const { return DenseRowStlIterator<const Xpr>(m_xpr, 0); } + + DenseRowStlIterator<Xpr> end() const { return DenseRowStlIterator<Xpr>(m_xpr, m_xpr.rows()); } + DenseRowStlIterator<const Xpr> cend() const { return DenseRowStlIterator<const Xpr>(m_xpr, m_xpr.rows()); } + +protected: + Xpr& m_xpr; +}; + +template<typename Derived> +ColsProxy<Derived> DenseBase<Derived>::allCols() +{ return ColsProxy<Derived>(derived()); } + +template<typename Derived> +ColsProxy<const Derived> DenseBase<Derived>::allCols() const +{ return ColsProxy<const Derived>(derived()); } + +template<typename Derived> +RowsProxy<Derived> DenseBase<Derived>::allRows() +{ return RowsProxy<Derived>(derived()); } + +template<typename Derived> +RowsProxy<const Derived> DenseBase<Derived>::allRows() const +{ return RowsProxy<const Derived>(derived()); } + +} // namespace Eigen diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index fca8a350e..d2532d854 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -133,6 +133,9 @@ template<typename ExpressionType> class ArrayWrapper; template<typename ExpressionType> class MatrixWrapper; template<typename Derived> class SolverBase; template<typename XprType> class InnerIterator; +template<typename XprType> class DenseStlIterator; +template<typename XprType> class ColsProxy; +template<typename XprType> class RowsProxy; namespace internal { template<typename DecompositionType> struct kernel_retval_base; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 45e7abbd1..f215d97cd 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -285,6 +285,7 @@ ei_add_test(inplace_decomposition) ei_add_test(half_float) ei_add_test(array_of_string) ei_add_test(num_dimensions) +ei_add_test(stl_iterators) add_executable(bug1213 bug1213.cpp bug1213_main.cpp) diff --git a/test/stl_iterators.cpp b/test/stl_iterators.cpp new file mode 100644 index 000000000..1ed52b354 --- /dev/null +++ b/test/stl_iterators.cpp @@ -0,0 +1,128 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2018 Gael Guennebaud <gael.guennebaud@inria.fr> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "main.h" + +template< class Iterator > +std::reverse_iterator<Iterator> +make_reverse_iterator( Iterator i ) +{ + return std::reverse_iterator<Iterator>(i); +} + +template<typename Scalar, int Rows, int Cols> +void test_range_for_loop(int rows=Rows, int cols=Cols) +{ + using std::begin; + using std::end; + + typedef Matrix<Scalar,Rows,1> VectorType; + typedef Matrix<Scalar,Rows,Cols,ColMajor> ColMatrixType; + typedef Matrix<Scalar,Rows,Cols,RowMajor> RowMatrixType; + VectorType v = VectorType::Random(rows); + ColMatrixType A = ColMatrixType::Random(rows,cols); + RowMatrixType B = RowMatrixType::Random(rows,cols); + + Index i, j; + +#if EIGEN_HAS_CXX11 + i = 0; + for(auto x : v) { VERIFY_IS_EQUAL(x,v[i++]); } + + j = internal::random<Index>(0,A.cols()-1); + i = 0; + for(auto x : A.col(j)) { VERIFY_IS_EQUAL(x,A(i++,j)); } + + i = 0; + for(auto x : (v+A.col(j))) { VERIFY_IS_APPROX(x,v(i)+A(i,j)); ++i; } + + j = 0; + i = internal::random<Index>(0,A.rows()-1); + for(auto x : A.row(i)) { VERIFY_IS_EQUAL(x,A(i,j++)); } + + i = 0; + for(auto x : A.reshaped()) { VERIFY_IS_EQUAL(x,A(i++)); } + + Matrix<Scalar,Dynamic,Dynamic,ColMajor> Bc = B; + i = 0; + for(auto x : B.reshaped()) { VERIFY_IS_EQUAL(x,Bc(i++)); } + + VectorType w(v.size()); + i = 0; + for(auto& x : w) { x = v(i++); } + VERIFY_IS_EQUAL(v,w); +#endif + + if(rows>=2) + { + v(1) = v(0)-Scalar(1); + VERIFY(!std::is_sorted(begin(v),end(v))); + } + std::sort(begin(v),end(v)); + VERIFY(std::is_sorted(begin(v),end(v))); + VERIFY(!std::is_sorted(make_reverse_iterator(end(v)),make_reverse_iterator(begin(v)))); + + { + j = internal::random<Index>(0,A.cols()-1); + // std::sort(begin(A.col(j)),end(A.col(j))); // does not compile because this returns const iterators + typename ColMatrixType::ColXpr Acol = A.col(j); + std::sort(begin(Acol),end(Acol)); + VERIFY(std::is_sorted(Acol.cbegin(),Acol.cend())); + + // This raises an assert because this creates a pair of iterator referencing two different proxy objects: + // std::sort(A.col(j).begin(),A.col(j).end()); + // VERIFY(std::is_sorted(A.col(j).cbegin(),A.col(j).cend())); // same issue + } + + { + j = internal::random<Index>(0,A.cols()-1); + typename ColMatrixType::ColXpr Acol = A.col(j); + std::partial_sum(begin(Acol), end(Acol), begin(v)); + VERIFY_IS_APPROX(v(seq(1,last)), v(seq(0,last-1))+Acol(seq(1,last))); + + // inplace + std::partial_sum(begin(Acol), end(Acol), begin(Acol)); + VERIFY_IS_APPROX(v, Acol); + } + +#if EIGEN_HAS_CXX11 + j = 0; + for(auto c : A.allCols()) { VERIFY_IS_APPROX(c.sum(), A.col(j).sum()); ++j; } + j = 0; + for(auto c : B.allCols()) { VERIFY_IS_APPROX(c.sum(), B.col(j).sum()); ++j; } + + j = 0; + for(auto c : B.allCols()) { + i = 0; + for(auto& x : c) { + VERIFY_IS_EQUAL(x, B(i,j)); + x = A(i,j); + ++i; + } + ++j; + } + VERIFY_IS_APPROX(A,B); + B = Bc; // restore B + + i = 0; + for(auto r : A.allRows()) { VERIFY_IS_APPROX(r.sum(), A.row(i).sum()); ++i; } + i = 0; + for(auto r : B.allRows()) { VERIFY_IS_APPROX(r.sum(), B.row(i).sum()); ++i; } + +#endif +} + +EIGEN_DECLARE_TEST(stl_iterators) +{ + for(int i = 0; i < g_repeat; i++) { + CALL_SUBTEST_1(( test_range_for_loop<double,2,3>() )); + CALL_SUBTEST_1(( test_range_for_loop<float,7,5>() )); + CALL_SUBTEST_1(( test_range_for_loop<int,Dynamic,Dynamic>(internal::random<int>(10,200), internal::random<int>(10,200)) )); + } +} |