aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-11-12 22:25:38 -0800
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-11-12 22:25:38 -0800
commitc2d1074932ae92a001eadb27e9f85eaf2de187b9 (patch)
treedf08df6eedb23a4a27d51d103978298b3dfa62e6
parentcb37f818ca6e8dfc9d81343882401e3671531d1b (diff)
Added support for static list of indices
-rw-r--r--unsupported/Eigen/CXX11/Tensor1
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h264
-rw-r--r--unsupported/test/CMakeLists.txt1
-rw-r--r--unsupported/test/cxx11_tensor_index_list.cpp133
4 files changed, 399 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/Tensor b/unsupported/Eigen/CXX11/Tensor
index c36db96ec..44d5a4d82 100644
--- a/unsupported/Eigen/CXX11/Tensor
+++ b/unsupported/Eigen/CXX11/Tensor
@@ -43,6 +43,7 @@
#include "unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h"
#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceType.h"
#include "unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h"
#include "unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h"
#include "unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h"
#include "unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h"
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h b/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h
new file mode 100644
index 000000000..010221e74
--- /dev/null
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h
@@ -0,0 +1,264 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H
+#define EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H
+
+#if __cplusplus > 199711L
+
+namespace Eigen {
+
+/** \internal
+ *
+ * \class TensorIndexList
+ * \ingroup CXX11_Tensor_Module
+ *
+ * \brief Set of classes used to encode a set of Tensor dimensions/indices.
+ *
+ * The indices in the list can be known at compile time or at runtime. A mix
+ * of static and dynamic indices can also be provided if needed. The tensor
+ * code will attempt to take advantage of the indices that are known at
+ * compile time to optimize the code it generates.
+ *
+ * This functionality requires a c++11 compliant compiler. If your compiler
+ * is older you need to use arrays of indices instead.
+ *
+ * Several examples are provided in the cxx11_tensor_index_list.cpp file.
+ *
+ * \sa Tensor
+ */
+
+template <DenseIndex n>
+struct type2index {
+ static const DenseIndex value = n;
+ constexpr operator DenseIndex() const { return n; }
+ void set(DenseIndex val) {
+ eigen_assert(val == n);
+ }
+};
+
+namespace internal {
+template <typename T>
+void update_value(T& val, DenseIndex new_val) {
+ val = new_val;
+}
+template <DenseIndex n>
+void update_value(type2index<n>& val, DenseIndex new_val) {
+ val.set(new_val);
+}
+
+template <typename T>
+struct is_compile_time_constant {
+ static constexpr bool value = false;
+};
+
+template <DenseIndex idx>
+struct is_compile_time_constant<type2index<idx> > {
+ static constexpr bool value = true;
+};
+template <DenseIndex idx>
+struct is_compile_time_constant<const type2index<idx> > {
+ static constexpr bool value = true;
+};
+template <DenseIndex idx>
+struct is_compile_time_constant<type2index<idx>& > {
+ static constexpr bool value = true;
+};
+template <DenseIndex idx>
+struct is_compile_time_constant<const type2index<idx>& > {
+ static constexpr bool value = true;
+};
+
+template <DenseIndex Idx>
+struct tuple_coeff {
+ template <typename... T>
+ static constexpr DenseIndex get(const DenseIndex i, const std::tuple<T...>& t) {
+ return std::get<Idx>(t) * (i == Idx) + tuple_coeff<Idx-1>::get(i, t) * (i != Idx);
+ }
+ template <typename... T>
+ static void set(const DenseIndex i, std::tuple<T...>& t, const DenseIndex value) {
+ if (i == Idx) {
+ update_value(std::get<Idx>(t), value);
+ } else {
+ tuple_coeff<Idx-1>::set(i, t, value);
+ }
+ }
+
+ template <typename... T>
+ static constexpr bool value_known_statically(const DenseIndex i, const std::tuple<T...>& t) {
+ return ((i == Idx) & is_compile_time_constant<typename std::tuple_element<Idx, std::tuple<T...> >::type>::value) ||
+ tuple_coeff<Idx-1>::value_known_statically(i, t);
+ }
+};
+
+template <>
+struct tuple_coeff<0> {
+ template <typename... T>
+ static constexpr DenseIndex get(const DenseIndex i, const std::tuple<T...>& t) {
+ // eigen_assert (i == 0); // gcc fails to compile assertions in constexpr
+ return std::get<0>(t) * (i == 0);
+ }
+ template <typename... T>
+ static void set(const DenseIndex i, std::tuple<T...>& t, const DenseIndex value) {
+ eigen_assert (i == 0);
+ update_value(std::get<0>(t), value);
+ }
+ template <typename... T>
+ static constexpr bool value_known_statically(const DenseIndex i, const std::tuple<T...>& t) {
+ // eigen_assert (i == 0); // gcc fails to compile assertions in constexpr
+ return is_compile_time_constant<typename std::tuple_element<0, std::tuple<T...> >::type>::value & (i == 0);
+ }
+};
+} // namespace internal
+
+
+template<typename FirstType, typename... OtherTypes>
+struct IndexList : std::tuple<FirstType, OtherTypes...> {
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC constexpr DenseIndex operator[] (const DenseIndex i) const {
+ return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value-1>::get(i, *this);
+ }
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void set(const DenseIndex i, const DenseIndex value) {
+ return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value-1>::set(i, *this, value);
+ }
+
+ constexpr IndexList(const std::tuple<FirstType, OtherTypes...>& other) : std::tuple<FirstType, OtherTypes...>(other) { }
+ constexpr IndexList() : std::tuple<FirstType, OtherTypes...>() { }
+
+ constexpr bool value_known_statically(const DenseIndex i) const {
+ return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value-1>::value_known_statically(i, *this);
+ }
+};
+
+
+template<typename FirstType, typename... OtherTypes>
+constexpr IndexList<FirstType, OtherTypes...> make_index_list(FirstType val1, OtherTypes... other_vals) {
+ return std::make_tuple(val1, other_vals...);
+}
+
+
+namespace internal {
+
+template<typename FirstType, typename... OtherTypes> struct array_size<IndexList<FirstType, OtherTypes...> > {
+ static const size_t value = std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value;
+};
+template<typename FirstType, typename... OtherTypes> struct array_size<const IndexList<FirstType, OtherTypes...> > {
+ static const size_t value = std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value;
+};
+
+template<DenseIndex n, typename FirstType, typename... OtherTypes> constexpr DenseIndex array_get(IndexList<FirstType, OtherTypes...>& a) {
+ return std::get<n>(a);
+}
+template<DenseIndex n, typename FirstType, typename... OtherTypes> constexpr DenseIndex array_get(const IndexList<FirstType, OtherTypes...>& a) {
+ return std::get<n>(a);
+}
+
+template <typename T>
+struct index_known_statically {
+ constexpr bool operator() (DenseIndex) const {
+ return false;
+ }
+};
+
+template <typename FirstType, typename... OtherTypes>
+struct index_known_statically<IndexList<FirstType, OtherTypes...> > {
+ constexpr bool operator() (const DenseIndex i) const {
+ return IndexList<FirstType, OtherTypes...>().value_known_statically(i);
+ }
+};
+
+template <typename FirstType, typename... OtherTypes>
+struct index_known_statically<const IndexList<FirstType, OtherTypes...> > {
+ constexpr bool operator() (const DenseIndex i) const {
+ return IndexList<FirstType, OtherTypes...>().value_known_statically(i);
+ }
+};
+
+template <typename Tx>
+struct index_statically_eq {
+ constexpr bool operator() (DenseIndex, DenseIndex) const {
+ return false;
+ }
+};
+
+template <typename FirstType, typename... OtherTypes>
+struct index_statically_eq<IndexList<FirstType, OtherTypes...> > {
+ constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
+ return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
+ IndexList<FirstType, OtherTypes...>()[i] == value;
+ }
+};
+
+template <typename FirstType, typename... OtherTypes>
+struct index_statically_eq<const IndexList<FirstType, OtherTypes...> > {
+ constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
+ return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
+ IndexList<FirstType, OtherTypes...>()[i] == value;
+ }
+};
+
+template <typename T>
+struct index_statically_ne {
+ constexpr bool operator() (DenseIndex, DenseIndex) const {
+ return false;
+ }
+};
+
+template <typename FirstType, typename... OtherTypes>
+struct index_statically_ne<IndexList<FirstType, OtherTypes...> > {
+ constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
+ return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
+ IndexList<FirstType, OtherTypes...>()[i] != value;
+ }
+};
+
+template <typename FirstType, typename... OtherTypes>
+struct index_statically_ne<const IndexList<FirstType, OtherTypes...> > {
+ constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
+ return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
+ IndexList<FirstType, OtherTypes...>()[i] != value;
+ }
+};
+
+
+} // end namespace internal
+} // end namespace Eigen
+
+#else
+
+namespace Eigen {
+namespace internal {
+
+// No C++11 support
+template <typename T>
+struct index_known_statically {
+ EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex) const{
+ return false;
+ }
+};
+
+template <typename T>
+struct index_statically_eq {
+ EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const{
+ return false;
+ }
+};
+
+template <typename T>
+struct index_statically_ne {
+ EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const{
+ return false;
+ }
+};
+
+} // end namespace internal
+} // end namespace Eigen
+
+#endif
+
+#endif // EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H
diff --git a/unsupported/test/CMakeLists.txt b/unsupported/test/CMakeLists.txt
index 6b8ed2826..181f06fc7 100644
--- a/unsupported/test/CMakeLists.txt
+++ b/unsupported/test/CMakeLists.txt
@@ -102,6 +102,7 @@ if(EIGEN_TEST_CXX11)
ei_add_test(cxx11_tensor_symmetry "-std=c++0x")
ei_add_test(cxx11_tensor_assign "-std=c++0x")
ei_add_test(cxx11_tensor_dimension "-std=c++0x")
+ ei_add_test(cxx11_tensor_index_list "-std=c++0x")
ei_add_test(cxx11_tensor_comparisons "-std=c++0x")
ei_add_test(cxx11_tensor_contraction "-std=c++0x")
ei_add_test(cxx11_tensor_convolution "-std=c++0x")
diff --git a/unsupported/test/cxx11_tensor_index_list.cpp b/unsupported/test/cxx11_tensor_index_list.cpp
new file mode 100644
index 000000000..6a103cab1
--- /dev/null
+++ b/unsupported/test/cxx11_tensor_index_list.cpp
@@ -0,0 +1,133 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#include "main.h"
+
+#include <Eigen/CXX11/Tensor>
+
+
+static void test_static_index_list()
+{
+ Tensor<float, 4> tensor(2,3,5,7);
+ tensor.setRandom();
+
+ constexpr auto reduction_axis = make_index_list(0, 1, 2);
+ VERIFY_IS_EQUAL(internal::array_get<0>(reduction_axis), 0);
+ VERIFY_IS_EQUAL(internal::array_get<1>(reduction_axis), 1);
+ VERIFY_IS_EQUAL(internal::array_get<2>(reduction_axis), 2);
+ VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[0]), 0);
+ VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[1]), 1);
+ VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[2]), 2);
+
+ EIGEN_STATIC_ASSERT((internal::array_get<0>(reduction_axis) == 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ EIGEN_STATIC_ASSERT((internal::array_get<1>(reduction_axis) == 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ EIGEN_STATIC_ASSERT((internal::array_get<2>(reduction_axis) == 2), YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ Tensor<float, 1> result = tensor.sum(reduction_axis);
+ for (int i = 0; i < result.size(); ++i) {
+ float expected = 0.0f;
+ for (int j = 0; j < 2; ++j) {
+ for (int k = 0; k < 3; ++k) {
+ for (int l = 0; l < 5; ++l) {
+ expected += tensor(j,k,l,i);
+ }
+ }
+ }
+ VERIFY_IS_APPROX(result(i), expected);
+ }
+}
+
+
+static void test_dynamic_index_list()
+{
+ Tensor<float, 4> tensor(2,3,5,7);
+ tensor.setRandom();
+
+ int dim1 = 2;
+ int dim2 = 1;
+ int dim3 = 0;
+
+ auto reduction_axis = make_index_list(dim1, dim2, dim3);
+
+ VERIFY_IS_EQUAL(internal::array_get<0>(reduction_axis), 2);
+ VERIFY_IS_EQUAL(internal::array_get<1>(reduction_axis), 1);
+ VERIFY_IS_EQUAL(internal::array_get<2>(reduction_axis), 0);
+ VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[0]), 2);
+ VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[1]), 1);
+ VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[2]), 0);
+
+ Tensor<float, 1> result = tensor.sum(reduction_axis);
+ for (int i = 0; i < result.size(); ++i) {
+ float expected = 0.0f;
+ for (int j = 0; j < 2; ++j) {
+ for (int k = 0; k < 3; ++k) {
+ for (int l = 0; l < 5; ++l) {
+ expected += tensor(j,k,l,i);
+ }
+ }
+ }
+ VERIFY_IS_APPROX(result(i), expected);
+ }
+}
+
+static void test_mixed_index_list()
+{
+ Tensor<float, 4> tensor(2,3,5,7);
+ tensor.setRandom();
+
+ int dim2 = 1;
+ int dim4 = 3;
+
+ auto reduction_axis = make_index_list(0, dim2, 2, dim4);
+
+ VERIFY_IS_EQUAL(internal::array_get<0>(reduction_axis), 0);
+ VERIFY_IS_EQUAL(internal::array_get<1>(reduction_axis), 1);
+ VERIFY_IS_EQUAL(internal::array_get<2>(reduction_axis), 2);
+ VERIFY_IS_EQUAL(internal::array_get<3>(reduction_axis), 3);
+ VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[0]), 0);
+ VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[1]), 1);
+ VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[2]), 2);
+ VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[3]), 3);
+
+ typedef IndexList<type2index<0>, int, type2index<2>, int> ReductionIndices;
+ ReductionIndices reduction_indices;
+ reduction_indices.set(1, 1);
+ reduction_indices.set(3, 3);
+ EIGEN_STATIC_ASSERT((internal::array_get<0>(reduction_indices) == 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ EIGEN_STATIC_ASSERT((internal::array_get<2>(reduction_indices) == 2), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ EIGEN_STATIC_ASSERT((internal::index_known_statically<ReductionIndices>()(0) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ EIGEN_STATIC_ASSERT((internal::index_known_statically<ReductionIndices>()(2) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ EIGEN_STATIC_ASSERT((internal::index_statically_eq<ReductionIndices>()(0, 0) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ EIGEN_STATIC_ASSERT((internal::index_statically_eq<ReductionIndices>()(2, 2) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+
+ Tensor<float, 1> result1 = tensor.sum(reduction_axis);
+ Tensor<float, 1> result2 = tensor.sum(reduction_indices);
+
+ float expected = 0.0f;
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ for (int k = 0; k < 5; ++k) {
+ for (int l = 0; l < 7; ++l) {
+ expected += tensor(i,j,k,l);
+ }
+ }
+ }
+ }
+ VERIFY_IS_APPROX(result1(0), expected);
+ VERIFY_IS_APPROX(result2(0), expected);
+}
+
+
+void test_cxx11_tensor_index_list()
+{
+ CALL_SUBTEST(test_static_index_list());
+ CALL_SUBTEST(test_dynamic_index_list());
+ CALL_SUBTEST(test_mixed_index_list());
+}