diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-06-29 14:04:15 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-06-29 14:04:15 -0700 |
commit | f0ce85b757ce237d763d7751bda61901e78d5dc8 (patch) | |
tree | 8d463b14daa68ea6e624c61e5cac612ed23d89d2 /unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h | |
parent | 670c71d906a4f0adc7edf266c996183ae8e4a2cc (diff) |
Improved support for fixed size tensors
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h index 836daea65..5928f0b0c 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h @@ -69,6 +69,31 @@ struct fixed_size_tensor_index_linearization_helper<Index, NumIndices, 0, RowMaj } }; +template<typename Index, std::size_t n> +struct fixed_size_tensor_index_extraction_helper +{ + template <typename Dimensions> EIGEN_DEVICE_FUNC + static inline Index run(const Index index, + const Dimensions& dimensions) + { + const Index mult = (index == n) ? 1 : 0; + return array_get<n>(dimensions) * mult + + fixed_size_tensor_index_extraction_helper<Index, n - 1>::run(index, dimensions); + } +}; + +template<typename Index> +struct fixed_size_tensor_index_extraction_helper<Index, 0> +{ + template <typename Dimensions> EIGEN_DEVICE_FUNC + static inline Index run(const Index index, + const Dimensions& dimensions) + { + const Index mult = (index == 0) ? 1 : 0; + return array_get<0>(dimensions) * mult; + } +}; + } // end namespace internal @@ -99,6 +124,10 @@ struct Sizes : internal::numeric_list<std::ptrdiff_t, Indices...> { } #endif + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t operator[] (const int index) const { + return internal::fixed_size_tensor_index_extraction_helper<std::ptrdiff_t, Base::count - 1>::run(index, *this); + } + template <typename T> Sizes& operator = (const T& /*other*/) { // add assertion failure if the size of other is different return *this; @@ -114,10 +143,12 @@ struct Sizes : internal::numeric_list<std::ptrdiff_t, Indices...> { } }; +namespace internal { template <typename std::ptrdiff_t... Indices> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_prod(const Sizes<Indices...>&) { return Sizes<Indices...>::total_size; } +} #else @@ -166,6 +197,24 @@ template <std::size_t V1=0, std::size_t V2=0, std::size_t V3=0, std::size_t V4=0 } #endif + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex operator[] (const int index) const { + switch (index) { + case 0: + return internal::get<0, Base>::value; + case 1: + return internal::get<1, Base>::value; + case 2: + return internal::get<2, Base>::value; + case 3: + return internal::get<3, Base>::value; + case 4: + return internal::get<4, Base>::value; + default: + eigen_assert(false && "index overflow"); + return static_cast<std::size_t>(-1); + } + } + template <typename T> Sizes& operator = (const T&) { // to do: check the size of other return *this; @@ -181,10 +230,12 @@ template <std::size_t V1=0, std::size_t V2=0, std::size_t V3=0, std::size_t V4=0 } }; +namespace internal { template <std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t array_prod(const Sizes<V1, V2, V3, V4, V5>&) { return Sizes<V1, V2, V3, V4, V5>::total_size; } +} #endif |