aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-06-29 14:04:15 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-06-29 14:04:15 -0700
commitf0ce85b757ce237d763d7751bda61901e78d5dc8 (patch)
tree8d463b14daa68ea6e624c61e5cac612ed23d89d2 /unsupported/Eigen/CXX11
parent670c71d906a4f0adc7edf266c996183ae8e4a2cc (diff)
Improved support for fixed size tensors
Diffstat (limited to 'unsupported/Eigen/CXX11')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/Tensor.h22
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h51
2 files changed, 73 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/Tensor.h b/unsupported/Eigen/CXX11/src/Tensor/Tensor.h
index 4dbdbfb3e..24953ec94 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/Tensor.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/Tensor.h
@@ -375,6 +375,28 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
resize(dims);
}
+#ifndef EIGEN_EMULATE_CXX11_META_H
+ template <typename std::ptrdiff_t... Indices>
+ EIGEN_DEVICE_FUNC
+ void resize(const Sizes<Indices...>& dimensions) {
+ array<Index, NumIndices> dims;
+ for (int i = 0; i < NumIndices; ++i) {
+ dims[i] = dimensions[i];
+ }
+ resize(dims);
+ }
+#else
+ template <std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5>
+ EIGEN_DEVICE_FUNC
+ void resize(const Sizes<V1, V2, V3, V4, V5>& dimensions) {
+ array<Index, NumIndices> dims;
+ for (int i = 0; i < NumIndices; ++i) {
+ dims[i] = dimensions[i];
+ }
+ resize(dims);
+ }
+#endif
+
protected:
bool checkIndexRange(const array<Index, NumIndices>& indices) const
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h
index 836daea65..5928f0b0c 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h
@@ -69,6 +69,31 @@ struct fixed_size_tensor_index_linearization_helper<Index, NumIndices, 0, RowMaj
}
};
+template<typename Index, std::size_t n>
+struct fixed_size_tensor_index_extraction_helper
+{
+ template <typename Dimensions> EIGEN_DEVICE_FUNC
+ static inline Index run(const Index index,
+ const Dimensions& dimensions)
+ {
+ const Index mult = (index == n) ? 1 : 0;
+ return array_get<n>(dimensions) * mult +
+ fixed_size_tensor_index_extraction_helper<Index, n - 1>::run(index, dimensions);
+ }
+};
+
+template<typename Index>
+struct fixed_size_tensor_index_extraction_helper<Index, 0>
+{
+ template <typename Dimensions> EIGEN_DEVICE_FUNC
+ static inline Index run(const Index index,
+ const Dimensions& dimensions)
+ {
+ const Index mult = (index == 0) ? 1 : 0;
+ return array_get<0>(dimensions) * mult;
+ }
+};
+
} // end namespace internal
@@ -99,6 +124,10 @@ struct Sizes : internal::numeric_list<std::ptrdiff_t, Indices...> {
}
#endif
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t operator[] (const int index) const {
+ return internal::fixed_size_tensor_index_extraction_helper<std::ptrdiff_t, Base::count - 1>::run(index, *this);
+ }
+
template <typename T> Sizes& operator = (const T& /*other*/) {
// add assertion failure if the size of other is different
return *this;
@@ -114,10 +143,12 @@ struct Sizes : internal::numeric_list<std::ptrdiff_t, Indices...> {
}
};
+namespace internal {
template <typename std::ptrdiff_t... Indices>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_prod(const Sizes<Indices...>&) {
return Sizes<Indices...>::total_size;
}
+}
#else
@@ -166,6 +197,24 @@ template <std::size_t V1=0, std::size_t V2=0, std::size_t V3=0, std::size_t V4=0
}
#endif
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex operator[] (const int index) const {
+ switch (index) {
+ case 0:
+ return internal::get<0, Base>::value;
+ case 1:
+ return internal::get<1, Base>::value;
+ case 2:
+ return internal::get<2, Base>::value;
+ case 3:
+ return internal::get<3, Base>::value;
+ case 4:
+ return internal::get<4, Base>::value;
+ default:
+ eigen_assert(false && "index overflow");
+ return static_cast<std::size_t>(-1);
+ }
+ }
+
template <typename T> Sizes& operator = (const T&) {
// to do: check the size of other
return *this;
@@ -181,10 +230,12 @@ template <std::size_t V1=0, std::size_t V2=0, std::size_t V3=0, std::size_t V4=0
}
};
+namespace internal {
template <std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t array_prod(const Sizes<V1, V2, V3, V4, V5>&) {
return Sizes<V1, V2, V3, V4, V5>::total_size;
}
+}
#endif