aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-05-25 11:04:14 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-05-25 11:04:14 -0700
commit58026905ae4a608abac33f59a782beae590a8371 (patch)
treed85fcd6bc009a0be687b35106fdc0b9ea9d27ae5 /unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h
parented783872ab9040bab52f7b142458f168f662e3f0 (diff)
Added support for statically known lists of pairs of indices
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h215
1 files changed, 188 insertions, 27 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h b/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h
index 4f0f4fd75..7aebd6f28 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h
@@ -10,6 +10,22 @@
#ifndef EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H
#define EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H
+/*namespace Eigen {
+
+template <typename Index> struct IndexPair {
+ constexpr EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE IndexPair() : first(0), second(0) {}
+ constexpr EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE IndexPair(Index f, Index s) : first(f), second(s) {}
+
+ EIGEN_DEVICE_FUNC void set(IndexPair<Index> val) {
+ first = val.first;
+ second = val.second;
+ }
+
+ Index first;
+ Index second;
+};
+}*/
+
#if EIGEN_HAS_CONSTEXPR && EIGEN_HAS_VARIADIC_TEMPLATES
#define EIGEN_HAS_INDEX_LIST
@@ -45,6 +61,24 @@ struct type2index {
}
};
+// This can be used with IndexPairList to get compile-time constant pairs,
+// such as IndexPairList<type2indexpair<1,2>, type2indexpair<3,4>>().
+template <DenseIndex f, DenseIndex s>
+struct type2indexpair {
+ static const DenseIndex first = f;
+ static const DenseIndex second = s;
+
+ constexpr EIGEN_DEVICE_FUNC operator IndexPair<DenseIndex>() const {
+ return IndexPair<DenseIndex>(f, s);
+ }
+
+ EIGEN_DEVICE_FUNC void set(const IndexPair<DenseIndex>& val) {
+ eigen_assert(val.first == f);
+ eigen_assert(val.second == s);
+ }
+};
+
+
template<DenseIndex n> struct NumTraits<type2index<n> >
{
typedef DenseIndex Real;
@@ -73,6 +107,16 @@ EIGEN_DEVICE_FUNC void update_value(type2index<n>& val, DenseIndex new_val) {
}
template <typename T>
+EIGEN_DEVICE_FUNC void update_value(T& val, IndexPair<DenseIndex> new_val) {
+ val = new_val;
+}
+template <DenseIndex f, DenseIndex s>
+EIGEN_DEVICE_FUNC void update_value(type2indexpair<f, s>& val, IndexPair<DenseIndex> new_val) {
+ val.set(new_val);
+}
+
+
+template <typename T>
struct is_compile_time_constant {
static constexpr bool value = false;
};
@@ -94,7 +138,22 @@ struct is_compile_time_constant<const type2index<idx>& > {
static constexpr bool value = true;
};
-
+template <DenseIndex f, DenseIndex s>
+struct is_compile_time_constant<type2indexpair<f, s> > {
+ static constexpr bool value = true;
+};
+template <DenseIndex f, DenseIndex s>
+struct is_compile_time_constant<const type2indexpair<f, s> > {
+ static constexpr bool value = true;
+};
+template <DenseIndex f, DenseIndex s>
+struct is_compile_time_constant<type2indexpair<f, s>& > {
+ static constexpr bool value = true;
+};
+template <DenseIndex f, DenseIndex s>
+struct is_compile_time_constant<const type2indexpair<f, s>& > {
+ static constexpr bool value = true;
+};
template<typename... T>
@@ -184,31 +243,32 @@ template <typename T, typename... O>
-template <DenseIndex Idx>
+template <DenseIndex Idx, typename ValueT>
struct tuple_coeff {
template <typename... T>
- EIGEN_DEVICE_FUNC static constexpr DenseIndex get(const DenseIndex i, const IndexTuple<T...>& t) {
- return array_get<Idx>(t) * (i == Idx) + tuple_coeff<Idx-1>::get(i, t) * (i != Idx);
+ EIGEN_DEVICE_FUNC static constexpr ValueT get(const DenseIndex i, const IndexTuple<T...>& t) {
+ // return array_get<Idx>(t) * (i == Idx) + tuple_coeff<Idx-1>::get(i, t) * (i != Idx);
+ return (i == Idx ? array_get<Idx>(t) : tuple_coeff<Idx-1, ValueT>::get(i, t));
}
template <typename... T>
- EIGEN_DEVICE_FUNC static void set(const DenseIndex i, IndexTuple<T...>& t, const DenseIndex value) {
+ EIGEN_DEVICE_FUNC static void set(const DenseIndex i, IndexTuple<T...>& t, const ValueT& value) {
if (i == Idx) {
update_value(array_get<Idx>(t), value);
} else {
- tuple_coeff<Idx-1>::set(i, t, value);
+ tuple_coeff<Idx-1, ValueT>::set(i, t, value);
}
}
template <typename... T>
EIGEN_DEVICE_FUNC static constexpr bool value_known_statically(const DenseIndex i, const IndexTuple<T...>& t) {
return ((i == Idx) & is_compile_time_constant<typename IndexTupleExtractor<Idx, T...>::ValType>::value) ||
- tuple_coeff<Idx-1>::value_known_statically(i, t);
+ tuple_coeff<Idx-1, ValueT>::value_known_statically(i, t);
}
template <typename... T>
EIGEN_DEVICE_FUNC static constexpr bool values_up_to_known_statically(const IndexTuple<T...>& t) {
return is_compile_time_constant<typename IndexTupleExtractor<Idx, T...>::ValType>::value &&
- tuple_coeff<Idx-1>::values_up_to_known_statically(t);
+ tuple_coeff<Idx-1, ValueT>::values_up_to_known_statically(t);
}
template <typename... T>
@@ -216,19 +276,19 @@ struct tuple_coeff {
return is_compile_time_constant<typename IndexTupleExtractor<Idx, T...>::ValType>::value &&
is_compile_time_constant<typename IndexTupleExtractor<Idx, T...>::ValType>::value &&
array_get<Idx>(t) > array_get<Idx-1>(t) &&
- tuple_coeff<Idx-1>::values_up_to_statically_known_to_increase(t);
+ tuple_coeff<Idx-1, ValueT>::values_up_to_statically_known_to_increase(t);
}
};
-template <>
-struct tuple_coeff<0> {
+template <typename ValueT>
+struct tuple_coeff<0, ValueT> {
template <typename... T>
- EIGEN_DEVICE_FUNC static constexpr DenseIndex get(const DenseIndex i, const IndexTuple<T...>& t) {
+ EIGEN_DEVICE_FUNC static constexpr ValueT get(const DenseIndex /*i*/, const IndexTuple<T...>& t) {
// eigen_assert (i == 0); // gcc fails to compile assertions in constexpr
- return array_get<0>(t) * (i == 0);
+ return array_get<0>(t)/* * (i == 0)*/;
}
template <typename... T>
- EIGEN_DEVICE_FUNC static void set(const DenseIndex i, IndexTuple<T...>& t, const DenseIndex value) {
+ EIGEN_DEVICE_FUNC static void set(const DenseIndex i, IndexTuple<T...>& t, const ValueT value) {
eigen_assert (i == 0);
update_value(array_get<0>(t), value);
}
@@ -254,13 +314,13 @@ struct tuple_coeff<0> {
template<typename FirstType, typename... OtherTypes>
struct IndexList : internal::IndexTuple<FirstType, OtherTypes...> {
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC constexpr DenseIndex operator[] (const DenseIndex i) const {
- return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1>::get(i, *this);
+ return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1, DenseIndex>::get(i, *this);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC constexpr DenseIndex get(const DenseIndex i) const {
- return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1>::get(i, *this);
+ return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1, DenseIndex>::get(i, *this);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void set(const DenseIndex i, const DenseIndex value) {
- return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1>::set(i, *this, value);
+ return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1, DenseIndex>::set(i, *this, value);
}
EIGEN_DEVICE_FUNC constexpr IndexList(const internal::IndexTuple<FirstType, OtherTypes...>& other) : internal::IndexTuple<FirstType, OtherTypes...>(other) { }
@@ -268,14 +328,14 @@ struct IndexList : internal::IndexTuple<FirstType, OtherTypes...> {
EIGEN_DEVICE_FUNC constexpr IndexList() : internal::IndexTuple<FirstType, OtherTypes...>() { }
EIGEN_DEVICE_FUNC constexpr bool value_known_statically(const DenseIndex i) const {
- return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1>::value_known_statically(i, *this);
+ return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1, DenseIndex>::value_known_statically(i, *this);
}
EIGEN_DEVICE_FUNC constexpr bool all_values_known_statically() const {
- return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1>::values_up_to_known_statically(*this);
+ return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1, DenseIndex>::values_up_to_known_statically(*this);
}
EIGEN_DEVICE_FUNC constexpr bool values_statically_known_to_increase() const {
- return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1>::values_up_to_statically_known_to_increase(*this);
+ return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1, DenseIndex>::values_up_to_statically_known_to_increase(*this);
}
};
@@ -286,6 +346,23 @@ constexpr IndexList<FirstType, OtherTypes...> make_index_list(FirstType val1, Ot
}
+template<typename FirstType, typename... OtherTypes>
+struct IndexPairList : internal::IndexTuple<FirstType, OtherTypes...> {
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC constexpr IndexPair<DenseIndex> operator[] (const DenseIndex i) const {
+ return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1, IndexPair<DenseIndex>>::get(i, *this);
+ }
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void set(const DenseIndex i, const IndexPair<DenseIndex> value) {
+ return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...>>::value-1, IndexPair<DenseIndex> >::set(i, *this, value);
+ }
+
+ EIGEN_DEVICE_FUNC constexpr IndexPairList(const internal::IndexTuple<FirstType, OtherTypes...>& other) : internal::IndexTuple<FirstType, OtherTypes...>(other) { }
+ EIGEN_DEVICE_FUNC constexpr IndexPairList() : internal::IndexTuple<FirstType, OtherTypes...>() { }
+
+ EIGEN_DEVICE_FUNC constexpr bool value_known_statically(const DenseIndex i) const {
+ return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1, DenseIndex>::value_known_statically(i, *this);
+ }
+};
+
namespace internal {
template<typename FirstType, typename... OtherTypes> size_t array_prod(const IndexList<FirstType, OtherTypes...>& sizes) {
@@ -303,6 +380,13 @@ template<typename FirstType, typename... OtherTypes> struct array_size<const Ind
static const size_t value = array_size<IndexTuple<FirstType, OtherTypes...> >::value;
};
+template<typename FirstType, typename... OtherTypes> struct array_size<IndexPairList<FirstType, OtherTypes...> > {
+ static const size_t value = std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value;
+};
+template<typename FirstType, typename... OtherTypes> struct array_size<const IndexPairList<FirstType, OtherTypes...> > {
+ static const size_t value = std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value;
+};
+
template<DenseIndex N, typename FirstType, typename... OtherTypes> EIGEN_DEVICE_FUNC constexpr DenseIndex array_get(IndexList<FirstType, OtherTypes...>& a) {
return IndexTupleExtractor<N, FirstType, OtherTypes...>::get_val(a);
}
@@ -472,6 +556,57 @@ struct index_statically_lt_impl<const IndexList<FirstType, OtherTypes...> > {
}
};
+
+
+template <typename Tx>
+struct index_pair_first_statically_eq_impl {
+ EIGEN_DEVICE_FUNC static constexpr bool run(DenseIndex, DenseIndex) {
+ return false;
+ }
+};
+
+template <typename FirstType, typename... OtherTypes>
+struct index_pair_first_statically_eq_impl<IndexPairList<FirstType, OtherTypes...> > {
+ EIGEN_DEVICE_FUNC static constexpr bool run(const DenseIndex i, const DenseIndex value) {
+ return IndexPairList<FirstType, OtherTypes...>().value_known_statically(i) &
+ (IndexPairList<FirstType, OtherTypes...>()[i].first == value);
+ }
+};
+
+template <typename FirstType, typename... OtherTypes>
+struct index_pair_first_statically_eq_impl<const IndexPairList<FirstType, OtherTypes...> > {
+ EIGEN_DEVICE_FUNC static constexpr bool run(const DenseIndex i, const DenseIndex value) {
+ return IndexPairList<FirstType, OtherTypes...>().value_known_statically(i) &
+ (IndexPairList<FirstType, OtherTypes...>()[i].first == value);
+ }
+};
+
+
+
+template <typename Tx>
+struct index_pair_second_statically_eq_impl {
+ EIGEN_DEVICE_FUNC static constexpr bool run(DenseIndex, DenseIndex) {
+ return false;
+ }
+};
+
+template <typename FirstType, typename... OtherTypes>
+struct index_pair_second_statically_eq_impl<IndexPairList<FirstType, OtherTypes...> > {
+ EIGEN_DEVICE_FUNC static constexpr bool run(const DenseIndex i, const DenseIndex value) {
+ return IndexPairList<FirstType, OtherTypes...>().value_known_statically(i) &
+ (IndexPairList<FirstType, OtherTypes...>()[i].second == value);
+ }
+};
+
+template <typename FirstType, typename... OtherTypes>
+struct index_pair_second_statically_eq_impl<const IndexPairList<FirstType, OtherTypes...> > {
+ EIGEN_DEVICE_FUNC static constexpr bool run(const DenseIndex i, const DenseIndex value) {
+ return IndexPairList<FirstType, OtherTypes...>().value_known_statically(i) &
+ (IndexPairList<FirstType, OtherTypes...>()[i].second == value);
+ }
+};
+
+
} // end namespace internal
} // end namespace Eigen
@@ -482,53 +617,69 @@ namespace internal {
template <typename T>
struct index_known_statically_impl {
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(const DenseIndex) {
+ static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run(const DenseIndex) {
return false;
}
};
template <typename T>
struct all_indices_known_statically_impl {
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run() {
+ static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run() {
return false;
}
};
template <typename T>
struct indices_statically_known_to_increase_impl {
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run() {
+ static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run() {
return false;
}
};
template <typename T>
struct index_statically_eq_impl {
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(DenseIndex, DenseIndex) {
+ static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run(DenseIndex, DenseIndex) {
return false;
}
};
template <typename T>
struct index_statically_ne_impl {
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(DenseIndex, DenseIndex) {
+ static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run(DenseIndex, DenseIndex) {
return false;
}
};
template <typename T>
struct index_statically_gt_impl {
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(DenseIndex, DenseIndex) {
+ static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run(DenseIndex, DenseIndex) {
return false;
}
};
template <typename T>
struct index_statically_lt_impl {
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(DenseIndex, DenseIndex) {
+ static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run(DenseIndex, DenseIndex) {
+ return false;
+ }
+};
+
+template <typename Tx>
+struct index_pair_first_statically_eq_impl {
+ static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run(DenseIndex, DenseIndex) {
+ return false;
+ }
+};
+
+template <typename Tx>
+struct index_pair_second_statically_eq_impl {
+ static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run(DenseIndex, DenseIndex) {
return false;
}
};
+
+
} // end namespace internal
} // end namespace Eigen
@@ -572,6 +723,16 @@ static EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bool index_statically_lt(DenseIndex i,
return index_statically_lt_impl<T>::run(i, value);
}
+template <typename T>
+static EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bool index_pair_first_statically_eq(DenseIndex i, DenseIndex value) {
+ return index_pair_first_statically_eq_impl<T>::run(i, value);
+}
+
+template <typename T>
+static EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bool index_pair_second_statically_eq(DenseIndex i, DenseIndex value) {
+ return index_pair_second_statically_eq_impl<T>::run(i, value);
+}
+
} // end namespace internal
} // end namespace Eigen