From c2d1074932ae92a001eadb27e9f85eaf2de187b9 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 12 Nov 2014 22:25:38 -0800 Subject: Added support for static list of indices --- .../Eigen/CXX11/src/Tensor/TensorIndexList.h | 264 +++++++++++++++++++++ 1 file changed, 264 insertions(+) create mode 100644 unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h b/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h new file mode 100644 index 000000000..010221e74 --- /dev/null +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h @@ -0,0 +1,264 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H +#define EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H + +#if __cplusplus > 199711L + +namespace Eigen { + +/** \internal + * + * \class TensorIndexList + * \ingroup CXX11_Tensor_Module + * + * \brief Set of classes used to encode a set of Tensor dimensions/indices. + * + * The indices in the list can be known at compile time or at runtime. A mix + * of static and dynamic indices can also be provided if needed. The tensor + * code will attempt to take advantage of the indices that are known at + * compile time to optimize the code it generates. + * + * This functionality requires a c++11 compliant compiler. If your compiler + * is older you need to use arrays of indices instead. + * + * Several examples are provided in the cxx11_tensor_index_list.cpp file. + * + * \sa Tensor + */ + +template +struct type2index { + static const DenseIndex value = n; + constexpr operator DenseIndex() const { return n; } + void set(DenseIndex val) { + eigen_assert(val == n); + } +}; + +namespace internal { +template +void update_value(T& val, DenseIndex new_val) { + val = new_val; +} +template +void update_value(type2index& val, DenseIndex new_val) { + val.set(new_val); +} + +template +struct is_compile_time_constant { + static constexpr bool value = false; +}; + +template +struct is_compile_time_constant > { + static constexpr bool value = true; +}; +template +struct is_compile_time_constant > { + static constexpr bool value = true; +}; +template +struct is_compile_time_constant& > { + static constexpr bool value = true; +}; +template +struct is_compile_time_constant& > { + static constexpr bool value = true; +}; + +template +struct tuple_coeff { + template + static constexpr DenseIndex get(const DenseIndex i, const std::tuple& t) { + return std::get(t) * (i == Idx) + tuple_coeff::get(i, t) * (i != Idx); + } + template + static void set(const DenseIndex i, std::tuple& t, const DenseIndex value) { + if (i == Idx) { + update_value(std::get(t), value); + } else { + tuple_coeff::set(i, t, value); + } + } + + template + static constexpr bool value_known_statically(const DenseIndex i, const std::tuple& t) { + return ((i == Idx) & is_compile_time_constant >::type>::value) || + tuple_coeff::value_known_statically(i, t); + } +}; + +template <> +struct tuple_coeff<0> { + template + static constexpr DenseIndex get(const DenseIndex i, const std::tuple& t) { + // eigen_assert (i == 0); // gcc fails to compile assertions in constexpr + return std::get<0>(t) * (i == 0); + } + template + static void set(const DenseIndex i, std::tuple& t, const DenseIndex value) { + eigen_assert (i == 0); + update_value(std::get<0>(t), value); + } + template + static constexpr bool value_known_statically(const DenseIndex i, const std::tuple& t) { + // eigen_assert (i == 0); // gcc fails to compile assertions in constexpr + return is_compile_time_constant >::type>::value & (i == 0); + } +}; +} // namespace internal + + +template +struct IndexList : std::tuple { + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC constexpr DenseIndex operator[] (const DenseIndex i) const { + return internal::tuple_coeff >::value-1>::get(i, *this); + } + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void set(const DenseIndex i, const DenseIndex value) { + return internal::tuple_coeff >::value-1>::set(i, *this, value); + } + + constexpr IndexList(const std::tuple& other) : std::tuple(other) { } + constexpr IndexList() : std::tuple() { } + + constexpr bool value_known_statically(const DenseIndex i) const { + return internal::tuple_coeff >::value-1>::value_known_statically(i, *this); + } +}; + + +template +constexpr IndexList make_index_list(FirstType val1, OtherTypes... other_vals) { + return std::make_tuple(val1, other_vals...); +} + + +namespace internal { + +template struct array_size > { + static const size_t value = std::tuple_size >::value; +}; +template struct array_size > { + static const size_t value = std::tuple_size >::value; +}; + +template constexpr DenseIndex array_get(IndexList& a) { + return std::get(a); +} +template constexpr DenseIndex array_get(const IndexList& a) { + return std::get(a); +} + +template +struct index_known_statically { + constexpr bool operator() (DenseIndex) const { + return false; + } +}; + +template +struct index_known_statically > { + constexpr bool operator() (const DenseIndex i) const { + return IndexList().value_known_statically(i); + } +}; + +template +struct index_known_statically > { + constexpr bool operator() (const DenseIndex i) const { + return IndexList().value_known_statically(i); + } +}; + +template +struct index_statically_eq { + constexpr bool operator() (DenseIndex, DenseIndex) const { + return false; + } +}; + +template +struct index_statically_eq > { + constexpr bool operator() (const DenseIndex i, const DenseIndex value) const { + return IndexList().value_known_statically(i) & + IndexList()[i] == value; + } +}; + +template +struct index_statically_eq > { + constexpr bool operator() (const DenseIndex i, const DenseIndex value) const { + return IndexList().value_known_statically(i) & + IndexList()[i] == value; + } +}; + +template +struct index_statically_ne { + constexpr bool operator() (DenseIndex, DenseIndex) const { + return false; + } +}; + +template +struct index_statically_ne > { + constexpr bool operator() (const DenseIndex i, const DenseIndex value) const { + return IndexList().value_known_statically(i) & + IndexList()[i] != value; + } +}; + +template +struct index_statically_ne > { + constexpr bool operator() (const DenseIndex i, const DenseIndex value) const { + return IndexList().value_known_statically(i) & + IndexList()[i] != value; + } +}; + + +} // end namespace internal +} // end namespace Eigen + +#else + +namespace Eigen { +namespace internal { + +// No C++11 support +template +struct index_known_statically { + EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex) const{ + return false; + } +}; + +template +struct index_statically_eq { + EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const{ + return false; + } +}; + +template +struct index_statically_ne { + EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const{ + return false; + } +}; + +} // end namespace internal +} // end namespace Eigen + +#endif + +#endif // EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H -- cgit v1.2.3