diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2014-11-12 22:25:38 -0800 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2014-11-12 22:25:38 -0800 |
commit | c2d1074932ae92a001eadb27e9f85eaf2de187b9 (patch) | |
tree | df08df6eedb23a4a27d51d103978298b3dfa62e6 /unsupported | |
parent | cb37f818ca6e8dfc9d81343882401e3671531d1b (diff) |
Added support for static list of indices
Diffstat (limited to 'unsupported')
-rw-r--r-- | unsupported/Eigen/CXX11/Tensor | 1 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h | 264 | ||||
-rw-r--r-- | unsupported/test/CMakeLists.txt | 1 | ||||
-rw-r--r-- | unsupported/test/cxx11_tensor_index_list.cpp | 133 |
4 files changed, 399 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/Tensor b/unsupported/Eigen/CXX11/Tensor index c36db96ec..44d5a4d82 100644 --- a/unsupported/Eigen/CXX11/Tensor +++ b/unsupported/Eigen/CXX11/Tensor @@ -43,6 +43,7 @@ #include "unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceType.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h" diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h b/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h new file mode 100644 index 000000000..010221e74 --- /dev/null +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h @@ -0,0 +1,264 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H +#define EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H + +#if __cplusplus > 199711L + +namespace Eigen { + +/** \internal + * + * \class TensorIndexList + * \ingroup CXX11_Tensor_Module + * + * \brief Set of classes used to encode a set of Tensor dimensions/indices. + * + * The indices in the list can be known at compile time or at runtime. A mix + * of static and dynamic indices can also be provided if needed. The tensor + * code will attempt to take advantage of the indices that are known at + * compile time to optimize the code it generates. + * + * This functionality requires a c++11 compliant compiler. If your compiler + * is older you need to use arrays of indices instead. + * + * Several examples are provided in the cxx11_tensor_index_list.cpp file. + * + * \sa Tensor + */ + +template <DenseIndex n> +struct type2index { + static const DenseIndex value = n; + constexpr operator DenseIndex() const { return n; } + void set(DenseIndex val) { + eigen_assert(val == n); + } +}; + +namespace internal { +template <typename T> +void update_value(T& val, DenseIndex new_val) { + val = new_val; +} +template <DenseIndex n> +void update_value(type2index<n>& val, DenseIndex new_val) { + val.set(new_val); +} + +template <typename T> +struct is_compile_time_constant { + static constexpr bool value = false; +}; + +template <DenseIndex idx> +struct is_compile_time_constant<type2index<idx> > { + static constexpr bool value = true; +}; +template <DenseIndex idx> +struct is_compile_time_constant<const type2index<idx> > { + static constexpr bool value = true; +}; +template <DenseIndex idx> +struct is_compile_time_constant<type2index<idx>& > { + static constexpr bool value = true; +}; +template <DenseIndex idx> +struct is_compile_time_constant<const type2index<idx>& > { + static constexpr bool value = true; +}; + +template <DenseIndex Idx> +struct tuple_coeff { + template <typename... T> + static constexpr DenseIndex get(const DenseIndex i, const std::tuple<T...>& t) { + return std::get<Idx>(t) * (i == Idx) + tuple_coeff<Idx-1>::get(i, t) * (i != Idx); + } + template <typename... T> + static void set(const DenseIndex i, std::tuple<T...>& t, const DenseIndex value) { + if (i == Idx) { + update_value(std::get<Idx>(t), value); + } else { + tuple_coeff<Idx-1>::set(i, t, value); + } + } + + template <typename... T> + static constexpr bool value_known_statically(const DenseIndex i, const std::tuple<T...>& t) { + return ((i == Idx) & is_compile_time_constant<typename std::tuple_element<Idx, std::tuple<T...> >::type>::value) || + tuple_coeff<Idx-1>::value_known_statically(i, t); + } +}; + +template <> +struct tuple_coeff<0> { + template <typename... T> + static constexpr DenseIndex get(const DenseIndex i, const std::tuple<T...>& t) { + // eigen_assert (i == 0); // gcc fails to compile assertions in constexpr + return std::get<0>(t) * (i == 0); + } + template <typename... T> + static void set(const DenseIndex i, std::tuple<T...>& t, const DenseIndex value) { + eigen_assert (i == 0); + update_value(std::get<0>(t), value); + } + template <typename... T> + static constexpr bool value_known_statically(const DenseIndex i, const std::tuple<T...>& t) { + // eigen_assert (i == 0); // gcc fails to compile assertions in constexpr + return is_compile_time_constant<typename std::tuple_element<0, std::tuple<T...> >::type>::value & (i == 0); + } +}; +} // namespace internal + + +template<typename FirstType, typename... OtherTypes> +struct IndexList : std::tuple<FirstType, OtherTypes...> { + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC constexpr DenseIndex operator[] (const DenseIndex i) const { + return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value-1>::get(i, *this); + } + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void set(const DenseIndex i, const DenseIndex value) { + return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value-1>::set(i, *this, value); + } + + constexpr IndexList(const std::tuple<FirstType, OtherTypes...>& other) : std::tuple<FirstType, OtherTypes...>(other) { } + constexpr IndexList() : std::tuple<FirstType, OtherTypes...>() { } + + constexpr bool value_known_statically(const DenseIndex i) const { + return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value-1>::value_known_statically(i, *this); + } +}; + + +template<typename FirstType, typename... OtherTypes> +constexpr IndexList<FirstType, OtherTypes...> make_index_list(FirstType val1, OtherTypes... other_vals) { + return std::make_tuple(val1, other_vals...); +} + + +namespace internal { + +template<typename FirstType, typename... OtherTypes> struct array_size<IndexList<FirstType, OtherTypes...> > { + static const size_t value = std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value; +}; +template<typename FirstType, typename... OtherTypes> struct array_size<const IndexList<FirstType, OtherTypes...> > { + static const size_t value = std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value; +}; + +template<DenseIndex n, typename FirstType, typename... OtherTypes> constexpr DenseIndex array_get(IndexList<FirstType, OtherTypes...>& a) { + return std::get<n>(a); +} +template<DenseIndex n, typename FirstType, typename... OtherTypes> constexpr DenseIndex array_get(const IndexList<FirstType, OtherTypes...>& a) { + return std::get<n>(a); +} + +template <typename T> +struct index_known_statically { + constexpr bool operator() (DenseIndex) const { + return false; + } +}; + +template <typename FirstType, typename... OtherTypes> +struct index_known_statically<IndexList<FirstType, OtherTypes...> > { + constexpr bool operator() (const DenseIndex i) const { + return IndexList<FirstType, OtherTypes...>().value_known_statically(i); + } +}; + +template <typename FirstType, typename... OtherTypes> +struct index_known_statically<const IndexList<FirstType, OtherTypes...> > { + constexpr bool operator() (const DenseIndex i) const { + return IndexList<FirstType, OtherTypes...>().value_known_statically(i); + } +}; + +template <typename Tx> +struct index_statically_eq { + constexpr bool operator() (DenseIndex, DenseIndex) const { + return false; + } +}; + +template <typename FirstType, typename... OtherTypes> +struct index_statically_eq<IndexList<FirstType, OtherTypes...> > { + constexpr bool operator() (const DenseIndex i, const DenseIndex value) const { + return IndexList<FirstType, OtherTypes...>().value_known_statically(i) & + IndexList<FirstType, OtherTypes...>()[i] == value; + } +}; + +template <typename FirstType, typename... OtherTypes> +struct index_statically_eq<const IndexList<FirstType, OtherTypes...> > { + constexpr bool operator() (const DenseIndex i, const DenseIndex value) const { + return IndexList<FirstType, OtherTypes...>().value_known_statically(i) & + IndexList<FirstType, OtherTypes...>()[i] == value; + } +}; + +template <typename T> +struct index_statically_ne { + constexpr bool operator() (DenseIndex, DenseIndex) const { + return false; + } +}; + +template <typename FirstType, typename... OtherTypes> +struct index_statically_ne<IndexList<FirstType, OtherTypes...> > { + constexpr bool operator() (const DenseIndex i, const DenseIndex value) const { + return IndexList<FirstType, OtherTypes...>().value_known_statically(i) & + IndexList<FirstType, OtherTypes...>()[i] != value; + } +}; + +template <typename FirstType, typename... OtherTypes> +struct index_statically_ne<const IndexList<FirstType, OtherTypes...> > { + constexpr bool operator() (const DenseIndex i, const DenseIndex value) const { + return IndexList<FirstType, OtherTypes...>().value_known_statically(i) & + IndexList<FirstType, OtherTypes...>()[i] != value; + } +}; + + +} // end namespace internal +} // end namespace Eigen + +#else + +namespace Eigen { +namespace internal { + +// No C++11 support +template <typename T> +struct index_known_statically { + EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex) const{ + return false; + } +}; + +template <typename T> +struct index_statically_eq { + EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const{ + return false; + } +}; + +template <typename T> +struct index_statically_ne { + EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const{ + return false; + } +}; + +} // end namespace internal +} // end namespace Eigen + +#endif + +#endif // EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H diff --git a/unsupported/test/CMakeLists.txt b/unsupported/test/CMakeLists.txt index 6b8ed2826..181f06fc7 100644 --- a/unsupported/test/CMakeLists.txt +++ b/unsupported/test/CMakeLists.txt @@ -102,6 +102,7 @@ if(EIGEN_TEST_CXX11) ei_add_test(cxx11_tensor_symmetry "-std=c++0x") ei_add_test(cxx11_tensor_assign "-std=c++0x") ei_add_test(cxx11_tensor_dimension "-std=c++0x") + ei_add_test(cxx11_tensor_index_list "-std=c++0x") ei_add_test(cxx11_tensor_comparisons "-std=c++0x") ei_add_test(cxx11_tensor_contraction "-std=c++0x") ei_add_test(cxx11_tensor_convolution "-std=c++0x") diff --git a/unsupported/test/cxx11_tensor_index_list.cpp b/unsupported/test/cxx11_tensor_index_list.cpp new file mode 100644 index 000000000..6a103cab1 --- /dev/null +++ b/unsupported/test/cxx11_tensor_index_list.cpp @@ -0,0 +1,133 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "main.h" + +#include <Eigen/CXX11/Tensor> + + +static void test_static_index_list() +{ + Tensor<float, 4> tensor(2,3,5,7); + tensor.setRandom(); + + constexpr auto reduction_axis = make_index_list(0, 1, 2); + VERIFY_IS_EQUAL(internal::array_get<0>(reduction_axis), 0); + VERIFY_IS_EQUAL(internal::array_get<1>(reduction_axis), 1); + VERIFY_IS_EQUAL(internal::array_get<2>(reduction_axis), 2); + VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[0]), 0); + VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[1]), 1); + VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[2]), 2); + + EIGEN_STATIC_ASSERT((internal::array_get<0>(reduction_axis) == 0), YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT((internal::array_get<1>(reduction_axis) == 1), YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT((internal::array_get<2>(reduction_axis) == 2), YOU_MADE_A_PROGRAMMING_MISTAKE); + + Tensor<float, 1> result = tensor.sum(reduction_axis); + for (int i = 0; i < result.size(); ++i) { + float expected = 0.0f; + for (int j = 0; j < 2; ++j) { + for (int k = 0; k < 3; ++k) { + for (int l = 0; l < 5; ++l) { + expected += tensor(j,k,l,i); + } + } + } + VERIFY_IS_APPROX(result(i), expected); + } +} + + +static void test_dynamic_index_list() +{ + Tensor<float, 4> tensor(2,3,5,7); + tensor.setRandom(); + + int dim1 = 2; + int dim2 = 1; + int dim3 = 0; + + auto reduction_axis = make_index_list(dim1, dim2, dim3); + + VERIFY_IS_EQUAL(internal::array_get<0>(reduction_axis), 2); + VERIFY_IS_EQUAL(internal::array_get<1>(reduction_axis), 1); + VERIFY_IS_EQUAL(internal::array_get<2>(reduction_axis), 0); + VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[0]), 2); + VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[1]), 1); + VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[2]), 0); + + Tensor<float, 1> result = tensor.sum(reduction_axis); + for (int i = 0; i < result.size(); ++i) { + float expected = 0.0f; + for (int j = 0; j < 2; ++j) { + for (int k = 0; k < 3; ++k) { + for (int l = 0; l < 5; ++l) { + expected += tensor(j,k,l,i); + } + } + } + VERIFY_IS_APPROX(result(i), expected); + } +} + +static void test_mixed_index_list() +{ + Tensor<float, 4> tensor(2,3,5,7); + tensor.setRandom(); + + int dim2 = 1; + int dim4 = 3; + + auto reduction_axis = make_index_list(0, dim2, 2, dim4); + + VERIFY_IS_EQUAL(internal::array_get<0>(reduction_axis), 0); + VERIFY_IS_EQUAL(internal::array_get<1>(reduction_axis), 1); + VERIFY_IS_EQUAL(internal::array_get<2>(reduction_axis), 2); + VERIFY_IS_EQUAL(internal::array_get<3>(reduction_axis), 3); + VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[0]), 0); + VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[1]), 1); + VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[2]), 2); + VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[3]), 3); + + typedef IndexList<type2index<0>, int, type2index<2>, int> ReductionIndices; + ReductionIndices reduction_indices; + reduction_indices.set(1, 1); + reduction_indices.set(3, 3); + EIGEN_STATIC_ASSERT((internal::array_get<0>(reduction_indices) == 0), YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT((internal::array_get<2>(reduction_indices) == 2), YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT((internal::index_known_statically<ReductionIndices>()(0) == true), YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT((internal::index_known_statically<ReductionIndices>()(2) == true), YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT((internal::index_statically_eq<ReductionIndices>()(0, 0) == true), YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT((internal::index_statically_eq<ReductionIndices>()(2, 2) == true), YOU_MADE_A_PROGRAMMING_MISTAKE); + + + Tensor<float, 1> result1 = tensor.sum(reduction_axis); + Tensor<float, 1> result2 = tensor.sum(reduction_indices); + + float expected = 0.0f; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 5; ++k) { + for (int l = 0; l < 7; ++l) { + expected += tensor(i,j,k,l); + } + } + } + } + VERIFY_IS_APPROX(result1(0), expected); + VERIFY_IS_APPROX(result2(0), expected); +} + + +void test_cxx11_tensor_index_list() +{ + CALL_SUBTEST(test_static_index_list()); + CALL_SUBTEST(test_dynamic_index_list()); + CALL_SUBTEST(test_mixed_index_list()); +} |