aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2017-01-09 23:42:16 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2017-01-09 23:42:16 +0100
commitb50c3e967e1676f248c93c1a79e6574ae746e2fd (patch)
tree04dbf025bfdb58a51ed2a1667fcbff0d9ee24ec8
parent68064e14fac8c72c05faaeff98c1b70e2dae6ee7 (diff)
Add a minimalistic symbolic scalar type with expression template and make use of it to define the last placeholder and to unify the return type of seq and seqN.
-rw-r--r--Eigen/src/Core/ArithmeticSequence.h218
-rw-r--r--test/indexed_view.cpp5
2 files changed, 196 insertions, 27 deletions
diff --git a/Eigen/src/Core/ArithmeticSequence.h b/Eigen/src/Core/ArithmeticSequence.h
index 71301797a..9f4fe327b 100644
--- a/Eigen/src/Core/ArithmeticSequence.h
+++ b/Eigen/src/Core/ArithmeticSequence.h
@@ -34,7 +34,7 @@ struct last_t {
int operator- (last_t) const { return 0; }
int operator- (shifted_last x) const { return -x.offset; }
};
-static const last_t last;
+static const last_t last_legacy;
struct shifted_end {
@@ -52,7 +52,145 @@ struct end_t {
int operator- (end_t) const { return 0; }
int operator- (shifted_end x) const { return -x.offset; }
};
-static const end_t end;
+static const end_t end_legacy;
+
+// A simple wrapper around an Index to provide the eval method.
+// We could also use a free-function symbolic_eval...
+class symbolic_value_wrapper {
+public:
+ symbolic_value_wrapper(Index val) : m_value(val) {}
+ template<typename T>
+ Index eval(const T&) const { return m_value; }
+protected:
+ Index m_value;
+};
+
+//--------------------------------------------------------------------------------
+// minimalistic symbolic scalar type
+//--------------------------------------------------------------------------------
+
+template<typename Tag> class symbolic_symbol;
+template<typename Arg0> class symbolic_negate;
+template<typename Arg1,typename Arg2> class symbolic_add;
+template<typename Arg1,typename Arg2> class symbolic_product;
+template<typename Arg1,typename Arg2> class symbolic_quotient;
+
+template<typename Derived>
+class symbolic_index_base
+{
+public:
+ const Derived& derived() const { return *static_cast<const Derived*>(this); }
+
+ symbolic_negate<Derived> operator-() const { return symbolic_negate<Derived>(derived()); }
+
+ symbolic_add<Derived,symbolic_value_wrapper> operator+(Index b) const
+ { return symbolic_add<Derived,symbolic_value_wrapper >(derived(), b); }
+ symbolic_add<Derived,symbolic_value_wrapper> operator-(Index a) const
+ { return symbolic_add<Derived,symbolic_value_wrapper >(derived(), -a); }
+ symbolic_quotient<Derived,symbolic_value_wrapper> operator/(Index a) const
+ { return symbolic_quotient<Derived,symbolic_value_wrapper>(derived(),a); }
+
+ friend symbolic_add<Derived,symbolic_value_wrapper> operator+(Index a, const symbolic_index_base& b)
+ { return symbolic_add<Derived,symbolic_value_wrapper>(b.derived(), a); }
+ friend symbolic_add<symbolic_negate<Derived>,symbolic_value_wrapper> operator-(Index a, const symbolic_index_base& b)
+ { return symbolic_add<symbolic_negate<Derived>,symbolic_value_wrapper>(-b.derived(), a); }
+ friend symbolic_add<symbolic_value_wrapper,Derived> operator/(Index a, const symbolic_index_base& b)
+ { return symbolic_add<symbolic_value_wrapper,Derived>(a,b.derived()); }
+
+ template<typename OtherDerived>
+ symbolic_add<Derived,OtherDerived> operator+(const symbolic_index_base<OtherDerived> &b) const
+ { return symbolic_add<Derived,OtherDerived>(derived(), b.derived()); }
+
+ template<typename OtherDerived>
+ symbolic_add<Derived,symbolic_negate<OtherDerived> > operator-(const symbolic_index_base<OtherDerived> &b) const
+ { return symbolic_add<Derived,symbolic_negate<OtherDerived> >(derived(), -b.derived()); }
+
+ template<typename OtherDerived>
+ symbolic_add<Derived,OtherDerived> operator/(const symbolic_index_base<OtherDerived> &b) const
+ { return symbolic_quotient<Derived,OtherDerived>(derived(), b.derived()); }
+};
+
+template<typename T>
+struct is_symbolic {
+ enum { value = internal::is_convertible<T,symbolic_index_base<T> >::value };
+};
+
+template<typename Tag>
+class symbolic_value_pair
+{
+public:
+ symbolic_value_pair(Index val) : m_value(val) {}
+ Index value() const { return m_value; }
+protected:
+ Index m_value;
+};
+
+template<typename Tag>
+class symbolic_value : public symbolic_index_base<symbolic_value<Tag> >
+{
+public:
+ symbolic_value() {}
+
+ Index eval(const symbolic_value_pair<Tag> &values) const { return values.value(); }
+
+ // TODO add a c++14 eval taking a tuple of symbolic_value_pair and getting the value with std::get<symbolic_value_pair<Tag> >...
+};
+
+template<typename Arg0>
+class symbolic_negate : public symbolic_index_base<symbolic_negate<Arg0> >
+{
+public:
+ symbolic_negate(const Arg0& arg0) : m_arg0(arg0) {}
+
+ template<typename T>
+ Index eval(const T& values) const { return -m_arg0.eval(values); }
+protected:
+ Arg0 m_arg0;
+};
+
+template<typename Arg0, typename Arg1>
+class symbolic_add : public symbolic_index_base<symbolic_add<Arg0,Arg1> >
+{
+public:
+ symbolic_add(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
+
+ template<typename T>
+ Index eval(const T& values) const { return m_arg0.eval(values) + m_arg1.eval(values); }
+protected:
+ Arg0 m_arg0;
+ Arg1 m_arg1;
+};
+
+template<typename Arg0, typename Arg1>
+class symbolic_product : public symbolic_index_base<symbolic_product<Arg0,Arg1> >
+{
+public:
+ symbolic_product(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
+
+ template<typename T>
+ Index eval(const T& values) const { return m_arg0.eval(values) * m_arg1.eval(values); }
+protected:
+ Arg0 m_arg0;
+ Arg1 m_arg1;
+};
+
+template<typename Arg0, typename Arg1>
+class symbolic_quotient : public symbolic_index_base<symbolic_quotient<Arg0,Arg1> >
+{
+public:
+ symbolic_quotient(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
+
+ template<typename T>
+ Index eval(const T& values) const { return m_arg0.eval(values) / m_arg1.eval(values); }
+protected:
+ Arg0 m_arg0;
+ Arg1 m_arg1;
+};
+
+struct symb_last_tag {};
+
+static const symbolic_value<symb_last_tag> last;
+static const symbolic_add<symbolic_value<symb_last_tag>,symbolic_value_wrapper> end(last+1);
//--------------------------------------------------------------------------------
// integral constant
@@ -116,34 +254,30 @@ protected:
IncrType m_incr;
};
-template<typename T> struct cleanup_slice_type { typedef Index type; };
-template<> struct cleanup_slice_type<last_t> { typedef last_t type; };
-template<> struct cleanup_slice_type<shifted_last> { typedef shifted_last type; };
-template<> struct cleanup_slice_type<end_t> { typedef end_t type; };
-template<> struct cleanup_slice_type<shifted_end> { typedef shifted_end type; };
-template<int N> struct cleanup_slice_type<fix_t<N> > { typedef fix_t<N> type; };
-template<int N> struct cleanup_slice_type<fix_t<N> (*)() > { typedef fix_t<N> type; };
+template<typename T> struct cleanup_seq_type { typedef T type; };
+template<int N> struct cleanup_seq_type<fix_t<N> > { typedef fix_t<N> type; };
+template<int N> struct cleanup_seq_type<fix_t<N> (*)() > { typedef fix_t<N> type; };
template<typename FirstType,typename LastType>
-ArithemeticSequenceProxyWithBounds<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<LastType>::type >
-seq(FirstType f, LastType l) {
- return ArithemeticSequenceProxyWithBounds<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<LastType>::type>(f,l);
+ArithemeticSequenceProxyWithBounds<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<LastType>::type >
+seq_legacy(FirstType f, LastType l) {
+ return ArithemeticSequenceProxyWithBounds<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<LastType>::type>(f,l);
}
template<typename FirstType,typename LastType,typename IncrType>
-ArithemeticSequenceProxyWithBounds<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<LastType>::type,typename cleanup_slice_type<IncrType>::type >
-seq(FirstType f, LastType l, IncrType s) {
- return ArithemeticSequenceProxyWithBounds<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<LastType>::type,typename cleanup_slice_type<IncrType>::type>(f,l,typename cleanup_slice_type<IncrType>::type(s));
+ArithemeticSequenceProxyWithBounds<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<LastType>::type,typename cleanup_seq_type<IncrType>::type >
+seq_legacy(FirstType f, LastType l, IncrType s) {
+ return ArithemeticSequenceProxyWithBounds<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<LastType>::type,typename cleanup_seq_type<IncrType>::type>(f,l,typename cleanup_seq_type<IncrType>::type(s));
}
template<typename FirstType=Index,typename SizeType=Index,typename IncrType=fix_t<1> >
-class ArithemeticSequenceProxyWithSize
+class ArithemeticSequence
{
public:
- ArithemeticSequenceProxyWithSize(FirstType first, SizeType size) : m_first(first), m_size(size) {}
- ArithemeticSequenceProxyWithSize(FirstType first, SizeType size, IncrType incr) : m_first(first), m_size(size), m_incr(incr) {}
+ ArithemeticSequence(FirstType first, SizeType size) : m_first(first), m_size(size) {}
+ ArithemeticSequence(FirstType first, SizeType size, IncrType incr) : m_first(first), m_size(size), m_incr(incr) {}
enum {
SizeAtCompileTime = get_compile_time<SizeType>::value,
@@ -165,18 +299,30 @@ protected:
template<typename FirstType,typename SizeType,typename IncrType>
-ArithemeticSequenceProxyWithSize<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<SizeType>::type,typename cleanup_slice_type<IncrType>::type >
+ArithemeticSequence<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<SizeType>::type,typename cleanup_seq_type<IncrType>::type >
seqN(FirstType first, SizeType size, IncrType incr) {
- return ArithemeticSequenceProxyWithSize<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<SizeType>::type,typename cleanup_slice_type<IncrType>::type>(first,size,incr);
+ return ArithemeticSequence<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<SizeType>::type,typename cleanup_seq_type<IncrType>::type>(first,size,incr);
}
template<typename FirstType,typename SizeType>
-ArithemeticSequenceProxyWithSize<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<SizeType>::type >
+ArithemeticSequence<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<SizeType>::type >
seqN(FirstType first, SizeType size) {
- return ArithemeticSequenceProxyWithSize<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<SizeType>::type>(first,size);
+ return ArithemeticSequence<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<SizeType>::type>(first,size);
}
+template<typename FirstType,typename LastType>
+auto seq(FirstType f, LastType l) -> decltype(seqN(f,(l-f+1)))
+{
+ return seqN(f,(l-f+1));
+}
+template<typename FirstType,typename LastType, typename IncrType>
+auto seq(FirstType f, LastType l, IncrType incr)
+ -> decltype(seqN(f,(l-f+typename cleanup_seq_type<IncrType>::type(incr))/typename cleanup_seq_type<IncrType>::type(incr),typename cleanup_seq_type<IncrType>::type(incr)))
+{
+ typedef typename cleanup_seq_type<IncrType>::type CleanedIncrType;
+ return seqN(f,(l-f+CleanedIncrType(incr))/CleanedIncrType(incr),CleanedIncrType(incr));
+}
namespace internal {
@@ -214,7 +360,7 @@ struct get_compile_time_incr<ArithemeticSequenceProxyWithBounds<FirstType,LastTy
};
template<typename FirstType,typename SizeType,typename IncrType>
-struct get_compile_time_incr<ArithemeticSequenceProxyWithSize<FirstType,SizeType,IncrType> > {
+struct get_compile_time_incr<ArithemeticSequence<FirstType,SizeType,IncrType> > {
enum { value = get_compile_time<IncrType,DynamicIndex>::value };
};
@@ -258,6 +404,17 @@ Index symbolic2value(shifted_last x, Index size) { return size+x.offset-1; }
Index symbolic2value(end_t, Index size) { return size; }
Index symbolic2value(shifted_end x, Index size) { return size+x.offset; }
+template<int N>
+fix_t<N> symbolic2value(fix_t<N> x, Index /*size*/) { return x; }
+
+template<typename Derived>
+Index symbolic2value(const symbolic_index_base<Derived> &x, Index size)
+{
+ Index h=x.derived().eval(symbolic_value_pair<symb_last_tag>(size-1));
+ return x.derived().eval(symbolic_value_pair<symb_last_tag>(size-1));
+}
+
+
// Convert a symbolic range into a usable one (i.e., remove last/end "keywords")
template<typename FirstType,typename LastType,typename IncrType>
struct MakeIndexing<ArithemeticSequenceProxyWithBounds<FirstType,LastType,IncrType> > {
@@ -270,14 +427,21 @@ ArithemeticSequenceProxyWithBounds<Index,Index,IncrType> make_indexing(const Ari
}
// Convert a symbolic span into a usable one (i.e., remove last/end "keywords")
+template<typename T>
+struct make_size_type {
+ typedef typename internal::conditional<is_symbolic<T>::value, Index, T>::type type;
+};
+
template<typename FirstType,typename SizeType,typename IncrType>
-struct MakeIndexing<ArithemeticSequenceProxyWithSize<FirstType,SizeType,IncrType> > {
- typedef ArithemeticSequenceProxyWithSize<Index,SizeType,IncrType> type;
+struct MakeIndexing<ArithemeticSequence<FirstType,SizeType,IncrType> > {
+ typedef ArithemeticSequence<Index,typename make_size_type<SizeType>::type,IncrType> type;
};
template<typename FirstType,typename SizeType,typename IncrType>
-ArithemeticSequenceProxyWithSize<Index,SizeType,IncrType> make_indexing(const ArithemeticSequenceProxyWithSize<FirstType,SizeType,IncrType>& ids, Index size) {
- return ArithemeticSequenceProxyWithSize<Index,SizeType,IncrType>(symbolic2value(ids.firstObject(),size),ids.sizeObject(),ids.incrObject());
+ArithemeticSequence<Index,typename make_size_type<SizeType>::type,IncrType>
+make_indexing(const ArithemeticSequence<FirstType,SizeType,IncrType>& ids, Index size) {
+ return ArithemeticSequence<Index,typename make_size_type<SizeType>::type,IncrType>(
+ symbolic2value(ids.firstObject(),size),symbolic2value(ids.sizeObject(),size),ids.incrObject());
}
// Convert a symbolic 'all' into a usable range
diff --git a/test/indexed_view.cpp b/test/indexed_view.cpp
index 23ad2d743..25a25499c 100644
--- a/test/indexed_view.cpp
+++ b/test/indexed_view.cpp
@@ -139,6 +139,11 @@ void check_indexed_view()
VERIFY_IS_EQUAL( (A(eii, eii)).InnerStrideAtCompileTime, 0);
VERIFY_IS_EQUAL( (A(eii, eii)).OuterStrideAtCompileTime, 0);
+ VERIFY_IS_APPROX( A(seq(n-1,2,-2), seqN(n-1-6,4)), A(seq(last,2,-2), seqN(last-6,4)) );
+ VERIFY_IS_APPROX( A(seq(n-1-6,n-1-2), seqN(n-1-6,4)), A(seq(last-6,last-2), seqN(6+last-6-6,4)) );
+ VERIFY_IS_APPROX( A(seq((n-1)/2,(n)/2+3), seqN(2,4)), A(seq(last/2,(last+1)/2+3), seqN(last+2-last,4)) );
+ VERIFY_IS_APPROX( A(seq(n-2,2,-2), seqN(n-8,4)), A(seq(end-2,2,-2), seqN(end-8,4)) );
+
#if EIGEN_HAS_CXX11
VERIFY( (A(all, std::array<int,4>{{1,3,2,4}})).ColsAtCompileTime == 4);