diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-10-15 14:58:49 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-10-15 14:58:49 -0700 |
commit | de1e9f29f4db2c837ffb354c90f9e9fb7df05e85 (patch) | |
tree | 99832d8d52f1b46063a82c1c9133e9b598df2d1b /unsupported | |
parent | 6585efc55354b38c65de8c23599e99f3caaca843 (diff) |
Updated the custom indexing code: we can now use any container that provides the [] operator to index a tensor. Added unit tests to validate the use of std::map and a few more types as valid custom index containers
Diffstat (limited to 'unsupported')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/Tensor.h | 17 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h | 8 | ||||
-rw-r--r-- | unsupported/test/cxx11_tensor_custom_index.cpp | 70 | ||||
-rw-r--r-- | unsupported/test/cxx11_tensor_simple.cpp | 4 |
4 files changed, 77 insertions, 22 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/Tensor.h b/unsupported/Eigen/CXX11/src/Tensor/Tensor.h index 57d44baf9..3ac465d24 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/Tensor.h +++ b/unsupported/Eigen/CXX11/src/Tensor/Tensor.h @@ -91,7 +91,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp #ifdef EIGEN_HAS_SFINAE template<typename CustomIndices> struct isOfNormalIndex{ - static const bool is_array = internal::is_base_of<array<Index, NumIndices>, CustomIndices >::value; + static const bool is_array = internal::is_base_of<array<Index, NumIndices>, CustomIndices>::value; static const bool is_int = NumTraits<CustomIndices>::IsInteger; static const bool value = is_array | is_int; }; @@ -120,11 +120,8 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp EIGEN_STATIC_ASSERT(sizeof...(otherIndices) + 2 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE) return coeff(array<Index, NumIndices>{{firstIndex, secondIndex, otherIndices...}}); } - - #endif - // normal indices EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(const array<Index, NumIndices>& indices) const { @@ -137,7 +134,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp template<typename CustomIndices, EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) ) > - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(const CustomIndices & indices) const + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(CustomIndices& indices) const { return coeff(internal::customIndices2Array<Index,NumIndices>(indices)); } @@ -171,7 +168,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp template<typename CustomIndices, EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) ) > - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(const CustomIndices & indices) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(CustomIndices& indices) { return coeffRef(internal::customIndices2Array<Index,NumIndices>(indices)); } @@ -219,7 +216,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp template<typename CustomIndices, EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) ) > - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()(const CustomIndices & indices) const + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()(CustomIndices& indices) const { return coeff(internal::customIndices2Array<Index,NumIndices>(indices)); } @@ -286,7 +283,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp template<typename CustomIndices, EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) ) > - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(const CustomIndices & indices) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(CustomIndices& indices) { return coeffRef(internal::customIndices2Array<Index,NumIndices>(indices)); } @@ -441,9 +438,9 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp template<typename CustomDimension, EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomDimension>::value) ) > - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void resize(const CustomDimension & dimensions) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void resize(CustomDimension& dimensions) { - return coeffRef(internal::customIndices2Array<Index,NumIndices>(dimensions)); + resize(internal::customIndices2Array<Index,NumIndices>(dimensions)); } #endif diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h index d1efc1a87..07735fa5f 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h @@ -82,15 +82,15 @@ namespace internal{ template<typename IndexType, Index... Is> EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - array<Index,sizeof...(Is)> customIndices2Array(const IndexType & idx, numeric_list<Index,Is...>) { - return { idx(Is)... }; + array<Index, sizeof...(Is)> customIndices2Array(IndexType& idx, numeric_list<Index, Is...>) { + return { idx[Is]... }; } /** Make an array (for index/dimensions) out of a custom index */ template<typename Index, int NumIndices, typename IndexType> EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - array<Index,NumIndices> customIndices2Array(const IndexType & idx) { - return customIndices2Array(idx, typename gen_numeric_list<Index,NumIndices>::type{}); + array<Index, NumIndices> customIndices2Array(IndexType& idx) { + return customIndices2Array(idx, typename gen_numeric_list<Index, NumIndices>::type{}); } diff --git a/unsupported/test/cxx11_tensor_custom_index.cpp b/unsupported/test/cxx11_tensor_custom_index.cpp index ff9545a7a..4528cc176 100644 --- a/unsupported/test/cxx11_tensor_custom_index.cpp +++ b/unsupported/test/cxx11_tensor_custom_index.cpp @@ -9,6 +9,7 @@ #include "main.h" #include <limits> +#include <map> #include <Eigen/Dense> #include <Eigen/CXX11/Tensor> @@ -17,22 +18,83 @@ using Eigen::Tensor; template <int DataLayout> -static void test_custom_index() { +static void test_map_as_index() +{ +#ifdef EIGEN_HAS_SFINAE + Tensor<float, 4, DataLayout> tensor(2, 3, 5, 7); + tensor.setRandom(); + + using NormalIndex = DSizes<ptrdiff_t, 4>; + using CustomIndex = std::map<ptrdiff_t, ptrdiff_t>; + CustomIndex coeffC; + coeffC[0] = 1; + coeffC[1] = 2; + coeffC[2] = 4; + coeffC[3] = 1; + NormalIndex coeff(1,2,4,1); + VERIFY_IS_EQUAL(tensor.coeff(coeffC), tensor.coeff(coeff)); + VERIFY_IS_EQUAL(tensor.coeffRef(coeffC), tensor.coeffRef(coeff)); +#endif +} + + +template <int DataLayout> +static void test_matrix_as_index() +{ +#ifdef EIGEN_HAS_SFINAE Tensor<float, 4, DataLayout> tensor(2, 3, 5, 7); tensor.setRandom(); using NormalIndex = DSizes<ptrdiff_t, 4>; - using CustomIndex = Matrix<unsigned int , 4, 1>; + using CustomIndex = Matrix<unsigned int, 4, 1>; CustomIndex coeffC(1,2,4,1); NormalIndex coeff(1,2,4,1); VERIFY_IS_EQUAL(tensor.coeff(coeffC), tensor.coeff(coeff)); VERIFY_IS_EQUAL(tensor.coeffRef(coeffC), tensor.coeffRef(coeff)); +#endif +} + + +template <int DataLayout> +static void test_varlist_as_index() +{ +#ifdef EIGEN_HAS_SFINAE + Tensor<float, 4, DataLayout> tensor(2, 3, 5, 7); + tensor.setRandom(); + + DSizes<ptrdiff_t, 4> coeff(1,2,4,1); + + VERIFY_IS_EQUAL(tensor.coeff({1,2,4,1}), tensor.coeff(coeff)); + VERIFY_IS_EQUAL(tensor.coeffRef({1,2,4,1}), tensor.coeffRef(coeff)); +#endif +} + + +template <int DataLayout> +static void test_sizes_as_index() +{ +#ifdef EIGEN_HAS_SFINAE + Tensor<float, 4, DataLayout> tensor(2, 3, 5, 7); + tensor.setRandom(); + + DSizes<ptrdiff_t, 4> coeff(1,2,4,1); + Sizes<1,2,4,1> coeffC; + + VERIFY_IS_EQUAL(tensor.coeff(coeffC), tensor.coeff(coeff)); + VERIFY_IS_EQUAL(tensor.coeffRef(coeffC), tensor.coeffRef(coeff)); +#endif } void test_cxx11_tensor_custom_index() { - test_custom_index<ColMajor>(); - test_custom_index<RowMajor>(); + test_map_as_index<ColMajor>(); + test_map_as_index<RowMajor>(); + test_matrix_as_index<ColMajor>(); + test_matrix_as_index<RowMajor>(); + test_varlist_as_index<ColMajor>(); + test_varlist_as_index<RowMajor>(); + test_sizes_as_index<ColMajor>(); + test_sizes_as_index<RowMajor>(); } diff --git a/unsupported/test/cxx11_tensor_simple.cpp b/unsupported/test/cxx11_tensor_simple.cpp index 8cd2ab7fd..0ce92eed9 100644 --- a/unsupported/test/cxx11_tensor_simple.cpp +++ b/unsupported/test/cxx11_tensor_simple.cpp @@ -293,7 +293,3 @@ void test_cxx11_tensor_simple() CALL_SUBTEST(test_simple_assign()); CALL_SUBTEST(test_resize()); } - -/* - * kate: space-indent on; indent-width 2; mixedindent off; indent-mode cstyle; - */ |