aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-06-29 14:04:15 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-06-29 14:04:15 -0700
commitf0ce85b757ce237d763d7751bda61901e78d5dc8 (patch)
tree8d463b14daa68ea6e624c61e5cac612ed23d89d2 /unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h
parent670c71d906a4f0adc7edf266c996183ae8e4a2cc (diff)
Improved support for fixed size tensors
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h51
1 files changed, 51 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h
index 836daea65..5928f0b0c 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h
@@ -69,6 +69,31 @@ struct fixed_size_tensor_index_linearization_helper<Index, NumIndices, 0, RowMaj
}
};
+template<typename Index, std::size_t n>
+struct fixed_size_tensor_index_extraction_helper
+{
+ template <typename Dimensions> EIGEN_DEVICE_FUNC
+ static inline Index run(const Index index,
+ const Dimensions& dimensions)
+ {
+ const Index mult = (index == n) ? 1 : 0;
+ return array_get<n>(dimensions) * mult +
+ fixed_size_tensor_index_extraction_helper<Index, n - 1>::run(index, dimensions);
+ }
+};
+
+template<typename Index>
+struct fixed_size_tensor_index_extraction_helper<Index, 0>
+{
+ template <typename Dimensions> EIGEN_DEVICE_FUNC
+ static inline Index run(const Index index,
+ const Dimensions& dimensions)
+ {
+ const Index mult = (index == 0) ? 1 : 0;
+ return array_get<0>(dimensions) * mult;
+ }
+};
+
} // end namespace internal
@@ -99,6 +124,10 @@ struct Sizes : internal::numeric_list<std::ptrdiff_t, Indices...> {
}
#endif
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t operator[] (const int index) const {
+ return internal::fixed_size_tensor_index_extraction_helper<std::ptrdiff_t, Base::count - 1>::run(index, *this);
+ }
+
template <typename T> Sizes& operator = (const T& /*other*/) {
// add assertion failure if the size of other is different
return *this;
@@ -114,10 +143,12 @@ struct Sizes : internal::numeric_list<std::ptrdiff_t, Indices...> {
}
};
+namespace internal {
template <typename std::ptrdiff_t... Indices>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_prod(const Sizes<Indices...>&) {
return Sizes<Indices...>::total_size;
}
+}
#else
@@ -166,6 +197,24 @@ template <std::size_t V1=0, std::size_t V2=0, std::size_t V3=0, std::size_t V4=0
}
#endif
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex operator[] (const int index) const {
+ switch (index) {
+ case 0:
+ return internal::get<0, Base>::value;
+ case 1:
+ return internal::get<1, Base>::value;
+ case 2:
+ return internal::get<2, Base>::value;
+ case 3:
+ return internal::get<3, Base>::value;
+ case 4:
+ return internal::get<4, Base>::value;
+ default:
+ eigen_assert(false && "index overflow");
+ return static_cast<std::size_t>(-1);
+ }
+ }
+
template <typename T> Sizes& operator = (const T&) {
// to do: check the size of other
return *this;
@@ -181,10 +230,12 @@ template <std::size_t V1=0, std::size_t V2=0, std::size_t V3=0, std::size_t V4=0
}
};
+namespace internal {
template <std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t array_prod(const Sizes<V1, V2, V3, V4, V5>&) {
return Sizes<V1, V2, V3, V4, V5>::total_size;
}
+}
#endif