diff options
author | Gael Guennebaud <g.gael@free.fr> | 2017-01-09 23:42:16 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2017-01-09 23:42:16 +0100 |
commit | b50c3e967e1676f248c93c1a79e6574ae746e2fd (patch) | |
tree | 04dbf025bfdb58a51ed2a1667fcbff0d9ee24ec8 | |
parent | 68064e14fac8c72c05faaeff98c1b70e2dae6ee7 (diff) |
Add a minimalistic symbolic scalar type with expression template and make use of it to define the last placeholder and to unify the return type of seq and seqN.
-rw-r--r-- | Eigen/src/Core/ArithmeticSequence.h | 218 | ||||
-rw-r--r-- | test/indexed_view.cpp | 5 |
2 files changed, 196 insertions, 27 deletions
diff --git a/Eigen/src/Core/ArithmeticSequence.h b/Eigen/src/Core/ArithmeticSequence.h index 71301797a..9f4fe327b 100644 --- a/Eigen/src/Core/ArithmeticSequence.h +++ b/Eigen/src/Core/ArithmeticSequence.h @@ -34,7 +34,7 @@ struct last_t { int operator- (last_t) const { return 0; } int operator- (shifted_last x) const { return -x.offset; } }; -static const last_t last; +static const last_t last_legacy; struct shifted_end { @@ -52,7 +52,145 @@ struct end_t { int operator- (end_t) const { return 0; } int operator- (shifted_end x) const { return -x.offset; } }; -static const end_t end; +static const end_t end_legacy; + +// A simple wrapper around an Index to provide the eval method. +// We could also use a free-function symbolic_eval... +class symbolic_value_wrapper { +public: + symbolic_value_wrapper(Index val) : m_value(val) {} + template<typename T> + Index eval(const T&) const { return m_value; } +protected: + Index m_value; +}; + +//-------------------------------------------------------------------------------- +// minimalistic symbolic scalar type +//-------------------------------------------------------------------------------- + +template<typename Tag> class symbolic_symbol; +template<typename Arg0> class symbolic_negate; +template<typename Arg1,typename Arg2> class symbolic_add; +template<typename Arg1,typename Arg2> class symbolic_product; +template<typename Arg1,typename Arg2> class symbolic_quotient; + +template<typename Derived> +class symbolic_index_base +{ +public: + const Derived& derived() const { return *static_cast<const Derived*>(this); } + + symbolic_negate<Derived> operator-() const { return symbolic_negate<Derived>(derived()); } + + symbolic_add<Derived,symbolic_value_wrapper> operator+(Index b) const + { return symbolic_add<Derived,symbolic_value_wrapper >(derived(), b); } + symbolic_add<Derived,symbolic_value_wrapper> operator-(Index a) const + { return symbolic_add<Derived,symbolic_value_wrapper >(derived(), -a); } + symbolic_quotient<Derived,symbolic_value_wrapper> operator/(Index a) const + { return symbolic_quotient<Derived,symbolic_value_wrapper>(derived(),a); } + + friend symbolic_add<Derived,symbolic_value_wrapper> operator+(Index a, const symbolic_index_base& b) + { return symbolic_add<Derived,symbolic_value_wrapper>(b.derived(), a); } + friend symbolic_add<symbolic_negate<Derived>,symbolic_value_wrapper> operator-(Index a, const symbolic_index_base& b) + { return symbolic_add<symbolic_negate<Derived>,symbolic_value_wrapper>(-b.derived(), a); } + friend symbolic_add<symbolic_value_wrapper,Derived> operator/(Index a, const symbolic_index_base& b) + { return symbolic_add<symbolic_value_wrapper,Derived>(a,b.derived()); } + + template<typename OtherDerived> + symbolic_add<Derived,OtherDerived> operator+(const symbolic_index_base<OtherDerived> &b) const + { return symbolic_add<Derived,OtherDerived>(derived(), b.derived()); } + + template<typename OtherDerived> + symbolic_add<Derived,symbolic_negate<OtherDerived> > operator-(const symbolic_index_base<OtherDerived> &b) const + { return symbolic_add<Derived,symbolic_negate<OtherDerived> >(derived(), -b.derived()); } + + template<typename OtherDerived> + symbolic_add<Derived,OtherDerived> operator/(const symbolic_index_base<OtherDerived> &b) const + { return symbolic_quotient<Derived,OtherDerived>(derived(), b.derived()); } +}; + +template<typename T> +struct is_symbolic { + enum { value = internal::is_convertible<T,symbolic_index_base<T> >::value }; +}; + +template<typename Tag> +class symbolic_value_pair +{ +public: + symbolic_value_pair(Index val) : m_value(val) {} + Index value() const { return m_value; } +protected: + Index m_value; +}; + +template<typename Tag> +class symbolic_value : public symbolic_index_base<symbolic_value<Tag> > +{ +public: + symbolic_value() {} + + Index eval(const symbolic_value_pair<Tag> &values) const { return values.value(); } + + // TODO add a c++14 eval taking a tuple of symbolic_value_pair and getting the value with std::get<symbolic_value_pair<Tag> >... +}; + +template<typename Arg0> +class symbolic_negate : public symbolic_index_base<symbolic_negate<Arg0> > +{ +public: + symbolic_negate(const Arg0& arg0) : m_arg0(arg0) {} + + template<typename T> + Index eval(const T& values) const { return -m_arg0.eval(values); } +protected: + Arg0 m_arg0; +}; + +template<typename Arg0, typename Arg1> +class symbolic_add : public symbolic_index_base<symbolic_add<Arg0,Arg1> > +{ +public: + symbolic_add(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {} + + template<typename T> + Index eval(const T& values) const { return m_arg0.eval(values) + m_arg1.eval(values); } +protected: + Arg0 m_arg0; + Arg1 m_arg1; +}; + +template<typename Arg0, typename Arg1> +class symbolic_product : public symbolic_index_base<symbolic_product<Arg0,Arg1> > +{ +public: + symbolic_product(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {} + + template<typename T> + Index eval(const T& values) const { return m_arg0.eval(values) * m_arg1.eval(values); } +protected: + Arg0 m_arg0; + Arg1 m_arg1; +}; + +template<typename Arg0, typename Arg1> +class symbolic_quotient : public symbolic_index_base<symbolic_quotient<Arg0,Arg1> > +{ +public: + symbolic_quotient(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {} + + template<typename T> + Index eval(const T& values) const { return m_arg0.eval(values) / m_arg1.eval(values); } +protected: + Arg0 m_arg0; + Arg1 m_arg1; +}; + +struct symb_last_tag {}; + +static const symbolic_value<symb_last_tag> last; +static const symbolic_add<symbolic_value<symb_last_tag>,symbolic_value_wrapper> end(last+1); //-------------------------------------------------------------------------------- // integral constant @@ -116,34 +254,30 @@ protected: IncrType m_incr; }; -template<typename T> struct cleanup_slice_type { typedef Index type; }; -template<> struct cleanup_slice_type<last_t> { typedef last_t type; }; -template<> struct cleanup_slice_type<shifted_last> { typedef shifted_last type; }; -template<> struct cleanup_slice_type<end_t> { typedef end_t type; }; -template<> struct cleanup_slice_type<shifted_end> { typedef shifted_end type; }; -template<int N> struct cleanup_slice_type<fix_t<N> > { typedef fix_t<N> type; }; -template<int N> struct cleanup_slice_type<fix_t<N> (*)() > { typedef fix_t<N> type; }; +template<typename T> struct cleanup_seq_type { typedef T type; }; +template<int N> struct cleanup_seq_type<fix_t<N> > { typedef fix_t<N> type; }; +template<int N> struct cleanup_seq_type<fix_t<N> (*)() > { typedef fix_t<N> type; }; template<typename FirstType,typename LastType> -ArithemeticSequenceProxyWithBounds<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<LastType>::type > -seq(FirstType f, LastType l) { - return ArithemeticSequenceProxyWithBounds<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<LastType>::type>(f,l); +ArithemeticSequenceProxyWithBounds<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<LastType>::type > +seq_legacy(FirstType f, LastType l) { + return ArithemeticSequenceProxyWithBounds<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<LastType>::type>(f,l); } template<typename FirstType,typename LastType,typename IncrType> -ArithemeticSequenceProxyWithBounds<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<LastType>::type,typename cleanup_slice_type<IncrType>::type > -seq(FirstType f, LastType l, IncrType s) { - return ArithemeticSequenceProxyWithBounds<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<LastType>::type,typename cleanup_slice_type<IncrType>::type>(f,l,typename cleanup_slice_type<IncrType>::type(s)); +ArithemeticSequenceProxyWithBounds<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<LastType>::type,typename cleanup_seq_type<IncrType>::type > +seq_legacy(FirstType f, LastType l, IncrType s) { + return ArithemeticSequenceProxyWithBounds<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<LastType>::type,typename cleanup_seq_type<IncrType>::type>(f,l,typename cleanup_seq_type<IncrType>::type(s)); } template<typename FirstType=Index,typename SizeType=Index,typename IncrType=fix_t<1> > -class ArithemeticSequenceProxyWithSize +class ArithemeticSequence { public: - ArithemeticSequenceProxyWithSize(FirstType first, SizeType size) : m_first(first), m_size(size) {} - ArithemeticSequenceProxyWithSize(FirstType first, SizeType size, IncrType incr) : m_first(first), m_size(size), m_incr(incr) {} + ArithemeticSequence(FirstType first, SizeType size) : m_first(first), m_size(size) {} + ArithemeticSequence(FirstType first, SizeType size, IncrType incr) : m_first(first), m_size(size), m_incr(incr) {} enum { SizeAtCompileTime = get_compile_time<SizeType>::value, @@ -165,18 +299,30 @@ protected: template<typename FirstType,typename SizeType,typename IncrType> -ArithemeticSequenceProxyWithSize<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<SizeType>::type,typename cleanup_slice_type<IncrType>::type > +ArithemeticSequence<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<SizeType>::type,typename cleanup_seq_type<IncrType>::type > seqN(FirstType first, SizeType size, IncrType incr) { - return ArithemeticSequenceProxyWithSize<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<SizeType>::type,typename cleanup_slice_type<IncrType>::type>(first,size,incr); + return ArithemeticSequence<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<SizeType>::type,typename cleanup_seq_type<IncrType>::type>(first,size,incr); } template<typename FirstType,typename SizeType> -ArithemeticSequenceProxyWithSize<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<SizeType>::type > +ArithemeticSequence<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<SizeType>::type > seqN(FirstType first, SizeType size) { - return ArithemeticSequenceProxyWithSize<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<SizeType>::type>(first,size); + return ArithemeticSequence<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<SizeType>::type>(first,size); } +template<typename FirstType,typename LastType> +auto seq(FirstType f, LastType l) -> decltype(seqN(f,(l-f+1))) +{ + return seqN(f,(l-f+1)); +} +template<typename FirstType,typename LastType, typename IncrType> +auto seq(FirstType f, LastType l, IncrType incr) + -> decltype(seqN(f,(l-f+typename cleanup_seq_type<IncrType>::type(incr))/typename cleanup_seq_type<IncrType>::type(incr),typename cleanup_seq_type<IncrType>::type(incr))) +{ + typedef typename cleanup_seq_type<IncrType>::type CleanedIncrType; + return seqN(f,(l-f+CleanedIncrType(incr))/CleanedIncrType(incr),CleanedIncrType(incr)); +} namespace internal { @@ -214,7 +360,7 @@ struct get_compile_time_incr<ArithemeticSequenceProxyWithBounds<FirstType,LastTy }; template<typename FirstType,typename SizeType,typename IncrType> -struct get_compile_time_incr<ArithemeticSequenceProxyWithSize<FirstType,SizeType,IncrType> > { +struct get_compile_time_incr<ArithemeticSequence<FirstType,SizeType,IncrType> > { enum { value = get_compile_time<IncrType,DynamicIndex>::value }; }; @@ -258,6 +404,17 @@ Index symbolic2value(shifted_last x, Index size) { return size+x.offset-1; } Index symbolic2value(end_t, Index size) { return size; } Index symbolic2value(shifted_end x, Index size) { return size+x.offset; } +template<int N> +fix_t<N> symbolic2value(fix_t<N> x, Index /*size*/) { return x; } + +template<typename Derived> +Index symbolic2value(const symbolic_index_base<Derived> &x, Index size) +{ + Index h=x.derived().eval(symbolic_value_pair<symb_last_tag>(size-1)); + return x.derived().eval(symbolic_value_pair<symb_last_tag>(size-1)); +} + + // Convert a symbolic range into a usable one (i.e., remove last/end "keywords") template<typename FirstType,typename LastType,typename IncrType> struct MakeIndexing<ArithemeticSequenceProxyWithBounds<FirstType,LastType,IncrType> > { @@ -270,14 +427,21 @@ ArithemeticSequenceProxyWithBounds<Index,Index,IncrType> make_indexing(const Ari } // Convert a symbolic span into a usable one (i.e., remove last/end "keywords") +template<typename T> +struct make_size_type { + typedef typename internal::conditional<is_symbolic<T>::value, Index, T>::type type; +}; + template<typename FirstType,typename SizeType,typename IncrType> -struct MakeIndexing<ArithemeticSequenceProxyWithSize<FirstType,SizeType,IncrType> > { - typedef ArithemeticSequenceProxyWithSize<Index,SizeType,IncrType> type; +struct MakeIndexing<ArithemeticSequence<FirstType,SizeType,IncrType> > { + typedef ArithemeticSequence<Index,typename make_size_type<SizeType>::type,IncrType> type; }; template<typename FirstType,typename SizeType,typename IncrType> -ArithemeticSequenceProxyWithSize<Index,SizeType,IncrType> make_indexing(const ArithemeticSequenceProxyWithSize<FirstType,SizeType,IncrType>& ids, Index size) { - return ArithemeticSequenceProxyWithSize<Index,SizeType,IncrType>(symbolic2value(ids.firstObject(),size),ids.sizeObject(),ids.incrObject()); +ArithemeticSequence<Index,typename make_size_type<SizeType>::type,IncrType> +make_indexing(const ArithemeticSequence<FirstType,SizeType,IncrType>& ids, Index size) { + return ArithemeticSequence<Index,typename make_size_type<SizeType>::type,IncrType>( + symbolic2value(ids.firstObject(),size),symbolic2value(ids.sizeObject(),size),ids.incrObject()); } // Convert a symbolic 'all' into a usable range diff --git a/test/indexed_view.cpp b/test/indexed_view.cpp index 23ad2d743..25a25499c 100644 --- a/test/indexed_view.cpp +++ b/test/indexed_view.cpp @@ -139,6 +139,11 @@ void check_indexed_view() VERIFY_IS_EQUAL( (A(eii, eii)).InnerStrideAtCompileTime, 0); VERIFY_IS_EQUAL( (A(eii, eii)).OuterStrideAtCompileTime, 0); + VERIFY_IS_APPROX( A(seq(n-1,2,-2), seqN(n-1-6,4)), A(seq(last,2,-2), seqN(last-6,4)) ); + VERIFY_IS_APPROX( A(seq(n-1-6,n-1-2), seqN(n-1-6,4)), A(seq(last-6,last-2), seqN(6+last-6-6,4)) ); + VERIFY_IS_APPROX( A(seq((n-1)/2,(n)/2+3), seqN(2,4)), A(seq(last/2,(last+1)/2+3), seqN(last+2-last,4)) ); + VERIFY_IS_APPROX( A(seq(n-2,2,-2), seqN(n-8,4)), A(seq(end-2,2,-2), seqN(end-8,4)) ); + #if EIGEN_HAS_CXX11 VERIFY( (A(all, std::array<int,4>{{1,3,2,4}})).ColsAtCompileTime == 4); |