From 58026905ae4a608abac33f59a782beae590a8371 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 25 May 2016 11:04:14 -0700 Subject: Added support for statically known lists of pairs of indices --- .../Eigen/CXX11/src/Tensor/TensorIndexList.h | 215 ++++++++++++++++++--- 1 file changed, 188 insertions(+), 27 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h b/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h index 4f0f4fd75..7aebd6f28 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h @@ -10,6 +10,22 @@ #ifndef EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H #define EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H +/*namespace Eigen { + +template struct IndexPair { + constexpr EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE IndexPair() : first(0), second(0) {} + constexpr EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE IndexPair(Index f, Index s) : first(f), second(s) {} + + EIGEN_DEVICE_FUNC void set(IndexPair val) { + first = val.first; + second = val.second; + } + + Index first; + Index second; +}; +}*/ + #if EIGEN_HAS_CONSTEXPR && EIGEN_HAS_VARIADIC_TEMPLATES #define EIGEN_HAS_INDEX_LIST @@ -45,6 +61,24 @@ struct type2index { } }; +// This can be used with IndexPairList to get compile-time constant pairs, +// such as IndexPairList, type2indexpair<3,4>>(). +template +struct type2indexpair { + static const DenseIndex first = f; + static const DenseIndex second = s; + + constexpr EIGEN_DEVICE_FUNC operator IndexPair() const { + return IndexPair(f, s); + } + + EIGEN_DEVICE_FUNC void set(const IndexPair& val) { + eigen_assert(val.first == f); + eigen_assert(val.second == s); + } +}; + + template struct NumTraits > { typedef DenseIndex Real; @@ -72,6 +106,16 @@ EIGEN_DEVICE_FUNC void update_value(type2index& val, DenseIndex new_val) { val.set(new_val); } +template +EIGEN_DEVICE_FUNC void update_value(T& val, IndexPair new_val) { + val = new_val; +} +template +EIGEN_DEVICE_FUNC void update_value(type2indexpair& val, IndexPair new_val) { + val.set(new_val); +} + + template struct is_compile_time_constant { static constexpr bool value = false; @@ -94,7 +138,22 @@ struct is_compile_time_constant& > { static constexpr bool value = true; }; - +template +struct is_compile_time_constant > { + static constexpr bool value = true; +}; +template +struct is_compile_time_constant > { + static constexpr bool value = true; +}; +template +struct is_compile_time_constant& > { + static constexpr bool value = true; +}; +template +struct is_compile_time_constant& > { + static constexpr bool value = true; +}; template @@ -184,31 +243,32 @@ template -template +template struct tuple_coeff { template - EIGEN_DEVICE_FUNC static constexpr DenseIndex get(const DenseIndex i, const IndexTuple& t) { - return array_get(t) * (i == Idx) + tuple_coeff::get(i, t) * (i != Idx); + EIGEN_DEVICE_FUNC static constexpr ValueT get(const DenseIndex i, const IndexTuple& t) { + // return array_get(t) * (i == Idx) + tuple_coeff::get(i, t) * (i != Idx); + return (i == Idx ? array_get(t) : tuple_coeff::get(i, t)); } template - EIGEN_DEVICE_FUNC static void set(const DenseIndex i, IndexTuple& t, const DenseIndex value) { + EIGEN_DEVICE_FUNC static void set(const DenseIndex i, IndexTuple& t, const ValueT& value) { if (i == Idx) { update_value(array_get(t), value); } else { - tuple_coeff::set(i, t, value); + tuple_coeff::set(i, t, value); } } template EIGEN_DEVICE_FUNC static constexpr bool value_known_statically(const DenseIndex i, const IndexTuple& t) { return ((i == Idx) & is_compile_time_constant::ValType>::value) || - tuple_coeff::value_known_statically(i, t); + tuple_coeff::value_known_statically(i, t); } template EIGEN_DEVICE_FUNC static constexpr bool values_up_to_known_statically(const IndexTuple& t) { return is_compile_time_constant::ValType>::value && - tuple_coeff::values_up_to_known_statically(t); + tuple_coeff::values_up_to_known_statically(t); } template @@ -216,19 +276,19 @@ struct tuple_coeff { return is_compile_time_constant::ValType>::value && is_compile_time_constant::ValType>::value && array_get(t) > array_get(t) && - tuple_coeff::values_up_to_statically_known_to_increase(t); + tuple_coeff::values_up_to_statically_known_to_increase(t); } }; -template <> -struct tuple_coeff<0> { +template +struct tuple_coeff<0, ValueT> { template - EIGEN_DEVICE_FUNC static constexpr DenseIndex get(const DenseIndex i, const IndexTuple& t) { + EIGEN_DEVICE_FUNC static constexpr ValueT get(const DenseIndex /*i*/, const IndexTuple& t) { // eigen_assert (i == 0); // gcc fails to compile assertions in constexpr - return array_get<0>(t) * (i == 0); + return array_get<0>(t)/* * (i == 0)*/; } template - EIGEN_DEVICE_FUNC static void set(const DenseIndex i, IndexTuple& t, const DenseIndex value) { + EIGEN_DEVICE_FUNC static void set(const DenseIndex i, IndexTuple& t, const ValueT value) { eigen_assert (i == 0); update_value(array_get<0>(t), value); } @@ -254,13 +314,13 @@ struct tuple_coeff<0> { template struct IndexList : internal::IndexTuple { EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC constexpr DenseIndex operator[] (const DenseIndex i) const { - return internal::tuple_coeff >::value-1>::get(i, *this); + return internal::tuple_coeff >::value-1, DenseIndex>::get(i, *this); } EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC constexpr DenseIndex get(const DenseIndex i) const { - return internal::tuple_coeff >::value-1>::get(i, *this); + return internal::tuple_coeff >::value-1, DenseIndex>::get(i, *this); } EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void set(const DenseIndex i, const DenseIndex value) { - return internal::tuple_coeff >::value-1>::set(i, *this, value); + return internal::tuple_coeff >::value-1, DenseIndex>::set(i, *this, value); } EIGEN_DEVICE_FUNC constexpr IndexList(const internal::IndexTuple& other) : internal::IndexTuple(other) { } @@ -268,14 +328,14 @@ struct IndexList : internal::IndexTuple { EIGEN_DEVICE_FUNC constexpr IndexList() : internal::IndexTuple() { } EIGEN_DEVICE_FUNC constexpr bool value_known_statically(const DenseIndex i) const { - return internal::tuple_coeff >::value-1>::value_known_statically(i, *this); + return internal::tuple_coeff >::value-1, DenseIndex>::value_known_statically(i, *this); } EIGEN_DEVICE_FUNC constexpr bool all_values_known_statically() const { - return internal::tuple_coeff >::value-1>::values_up_to_known_statically(*this); + return internal::tuple_coeff >::value-1, DenseIndex>::values_up_to_known_statically(*this); } EIGEN_DEVICE_FUNC constexpr bool values_statically_known_to_increase() const { - return internal::tuple_coeff >::value-1>::values_up_to_statically_known_to_increase(*this); + return internal::tuple_coeff >::value-1, DenseIndex>::values_up_to_statically_known_to_increase(*this); } }; @@ -286,6 +346,23 @@ constexpr IndexList make_index_list(FirstType val1, Ot } +template +struct IndexPairList : internal::IndexTuple { + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC constexpr IndexPair operator[] (const DenseIndex i) const { + return internal::tuple_coeff >::value-1, IndexPair>::get(i, *this); + } + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void set(const DenseIndex i, const IndexPair value) { + return internal::tuple_coeff>::value-1, IndexPair >::set(i, *this, value); + } + + EIGEN_DEVICE_FUNC constexpr IndexPairList(const internal::IndexTuple& other) : internal::IndexTuple(other) { } + EIGEN_DEVICE_FUNC constexpr IndexPairList() : internal::IndexTuple() { } + + EIGEN_DEVICE_FUNC constexpr bool value_known_statically(const DenseIndex i) const { + return internal::tuple_coeff >::value-1, DenseIndex>::value_known_statically(i, *this); + } +}; + namespace internal { template size_t array_prod(const IndexList& sizes) { @@ -303,6 +380,13 @@ template struct array_size >::value; }; +template struct array_size > { + static const size_t value = std::tuple_size >::value; +}; +template struct array_size > { + static const size_t value = std::tuple_size >::value; +}; + template EIGEN_DEVICE_FUNC constexpr DenseIndex array_get(IndexList& a) { return IndexTupleExtractor::get_val(a); } @@ -472,6 +556,57 @@ struct index_statically_lt_impl > { } }; + + +template +struct index_pair_first_statically_eq_impl { + EIGEN_DEVICE_FUNC static constexpr bool run(DenseIndex, DenseIndex) { + return false; + } +}; + +template +struct index_pair_first_statically_eq_impl > { + EIGEN_DEVICE_FUNC static constexpr bool run(const DenseIndex i, const DenseIndex value) { + return IndexPairList().value_known_statically(i) & + (IndexPairList()[i].first == value); + } +}; + +template +struct index_pair_first_statically_eq_impl > { + EIGEN_DEVICE_FUNC static constexpr bool run(const DenseIndex i, const DenseIndex value) { + return IndexPairList().value_known_statically(i) & + (IndexPairList()[i].first == value); + } +}; + + + +template +struct index_pair_second_statically_eq_impl { + EIGEN_DEVICE_FUNC static constexpr bool run(DenseIndex, DenseIndex) { + return false; + } +}; + +template +struct index_pair_second_statically_eq_impl > { + EIGEN_DEVICE_FUNC static constexpr bool run(const DenseIndex i, const DenseIndex value) { + return IndexPairList().value_known_statically(i) & + (IndexPairList()[i].second == value); + } +}; + +template +struct index_pair_second_statically_eq_impl > { + EIGEN_DEVICE_FUNC static constexpr bool run(const DenseIndex i, const DenseIndex value) { + return IndexPairList().value_known_statically(i) & + (IndexPairList()[i].second == value); + } +}; + + } // end namespace internal } // end namespace Eigen @@ -482,53 +617,69 @@ namespace internal { template struct index_known_statically_impl { - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(const DenseIndex) { + static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run(const DenseIndex) { return false; } }; template struct all_indices_known_statically_impl { - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run() { + static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run() { return false; } }; template struct indices_statically_known_to_increase_impl { - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run() { + static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run() { return false; } }; template struct index_statically_eq_impl { - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(DenseIndex, DenseIndex) { + static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run(DenseIndex, DenseIndex) { return false; } }; template struct index_statically_ne_impl { - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(DenseIndex, DenseIndex) { + static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run(DenseIndex, DenseIndex) { return false; } }; template struct index_statically_gt_impl { - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(DenseIndex, DenseIndex) { + static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run(DenseIndex, DenseIndex) { return false; } }; template struct index_statically_lt_impl { - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(DenseIndex, DenseIndex) { + static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run(DenseIndex, DenseIndex) { + return false; + } +}; + +template +struct index_pair_first_statically_eq_impl { + static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run(DenseIndex, DenseIndex) { + return false; + } +}; + +template +struct index_pair_second_statically_eq_impl { + static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run(DenseIndex, DenseIndex) { return false; } }; + + } // end namespace internal } // end namespace Eigen @@ -572,6 +723,16 @@ static EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bool index_statically_lt(DenseIndex i, return index_statically_lt_impl::run(i, value); } +template +static EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bool index_pair_first_statically_eq(DenseIndex i, DenseIndex value) { + return index_pair_first_statically_eq_impl::run(i, value); +} + +template +static EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bool index_pair_second_statically_eq(DenseIndex i, DenseIndex value) { + return index_pair_second_statically_eq_impl::run(i, value); +} + } // end namespace internal } // end namespace Eigen -- cgit v1.2.3