aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h133
1 files changed, 130 insertions, 3 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h
index 11590b474..732c6b344 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h
@@ -37,8 +37,7 @@ template <typename Index> struct IndexPair {
Index second;
};
-
-// Boiler plate code
+// Boilerplate code
namespace internal {
template<std::size_t n, typename Dimension> struct dget {
@@ -110,6 +109,11 @@ struct Sizes : internal::numeric_list<std::size_t, Indices...> {
}
};
+template <typename std::size_t... Indices>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t array_prod(const Sizes<Indices...>&) {
+ return Sizes<Indices...>::total_size;
+}
+
#else
template <std::size_t n>
@@ -136,9 +140,21 @@ template <std::size_t V1=0, std::size_t V2=0, std::size_t V3=0, std::size_t V4=0
// todo: add assertion
}
#ifdef EIGEN_HAS_VARIADIC_TEMPLATES
+ template <typename... DenseIndex> Sizes(DenseIndex... indices) { }
explicit Sizes(std::initializer_list<std::size_t> l) {
// todo: add assertion
}
+#else
+ EIGEN_DEVICE_FUNC explicit Sizes(const DenseIndex i0) {
+ }
+ EIGEN_DEVICE_FUNC explicit Sizes(const DenseIndex i0, const DenseIndex i1) {
+ }
+ EIGEN_DEVICE_FUNC explicit Sizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2) {
+ }
+ EIGEN_DEVICE_FUNC explicit Sizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2, const DenseIndex i3) {
+ }
+ EIGEN_DEVICE_FUNC explicit Sizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2, const DenseIndex i3, const DenseIndex i4) {
+ }
#endif
template <typename T> Sizes& operator = (const T& other) {
@@ -156,9 +172,14 @@ template <std::size_t V1=0, std::size_t V2=0, std::size_t V3=0, std::size_t V4=0
}
};
+template <std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t array_prod(const Sizes<V1, V2, V3, V4, V5>&) {
+ return Sizes<V1, V2, V3, V4, V5>::total_size;
+};
+
#endif
-// Boiler plate
+// Boilerplate
namespace internal {
template<typename Index, std::size_t NumIndices, std::size_t n, bool RowMajor>
struct tensor_index_linearization_helper
@@ -243,6 +264,112 @@ struct DSizes : array<DenseIndex, NumDims> {
};
+
+
+// Boilerplate
+namespace internal {
+template<typename Index, std::size_t NumIndices, std::size_t n, bool RowMajor>
+struct tensor_vsize_index_linearization_helper
+{
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Index run(array<Index, NumIndices> const& indices, std::vector<DenseIndex> const& dimensions)
+ {
+ return array_get<RowMajor ? n : (NumIndices - n - 1)>(indices) +
+ array_get<RowMajor ? n : (NumIndices - n - 1)>(dimensions) *
+ tensor_vsize_index_linearization_helper<Index, NumIndices, n - 1, RowMajor>::run(indices, dimensions);
+ }
+};
+
+template<typename Index, std::size_t NumIndices, bool RowMajor>
+struct tensor_vsize_index_linearization_helper<Index, NumIndices, 0, RowMajor>
+{
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Index run(array<Index, NumIndices> const& indices, std::vector<DenseIndex> const&)
+ {
+ return array_get<RowMajor ? 0 : NumIndices - 1>(indices);
+ }
+};
+} // end namespace internal
+
+template <typename DenseIndex>
+struct VSizes : std::vector<DenseIndex> {
+ typedef std::vector<DenseIndex> Base;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t TotalSize() const {
+ return internal::array_prod(*static_cast<const Base*>(this));
+ }
+
+ EIGEN_DEVICE_FUNC VSizes() { }
+ EIGEN_DEVICE_FUNC explicit VSizes(const std::vector<DenseIndex>& a) : Base(a) { }
+
+ template <std::size_t NumDims>
+ EIGEN_DEVICE_FUNC explicit VSizes(const array<DenseIndex, NumDims>& a) {
+ this->resize(NumDims);
+ for (int i = 0; i < NumDims; ++i) {
+ (*this)[i] = a[i];
+ }
+ }
+
+ EIGEN_DEVICE_FUNC explicit VSizes(const DenseIndex i0) {
+ this->resize(1);
+ (*this)[0] = i0;
+ }
+ EIGEN_DEVICE_FUNC explicit VSizes(const DenseIndex i0, const DenseIndex i1) {
+ this->resize(2);
+ (*this)[0] = i0;
+ (*this)[1] = i1;
+ }
+ EIGEN_DEVICE_FUNC explicit VSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2) {
+ this->resize(3);
+ (*this)[0] = i0;
+ (*this)[1] = i1;
+ (*this)[2] = i2;
+ }
+ EIGEN_DEVICE_FUNC explicit VSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2, const DenseIndex i3) {
+ this->resize(4);
+ (*this)[0] = i0;
+ (*this)[1] = i1;
+ (*this)[2] = i2;
+ (*this)[3] = i3;
+ }
+ EIGEN_DEVICE_FUNC explicit VSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2, const DenseIndex i3, const DenseIndex i4) {
+ this->resize(5);
+ (*this)[0] = i0;
+ (*this)[1] = i1;
+ (*this)[2] = i2;
+ (*this)[3] = i3;
+ (*this)[4] = i4;
+ }
+
+ VSizes& operator = (const std::vector<DenseIndex>& other) {
+ *static_cast<Base*>(this) = other;
+ return *this;
+ }
+
+ // A constexpr would be so much better here
+ template <std::size_t NumDims>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t IndexOfColMajor(const array<DenseIndex, NumDims>& indices) const {
+ return internal::tensor_vsize_index_linearization_helper<DenseIndex, NumDims, NumDims - 1, false>::run(indices, *static_cast<const Base*>(this));
+ }
+ template <std::size_t NumDims>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t IndexOfRowMajor(const array<DenseIndex, NumDims>& indices) const {
+ return internal::tensor_vsize_index_linearization_helper<DenseIndex, NumDims, NumDims - 1, true>::run(indices, *static_cast<const Base*>(this));
+ }
+};
+
+
+// Boilerplate
+namespace internal {
+template <typename DenseIndex>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex array_prod(const VSizes<DenseIndex>& sizes) {
+ DenseIndex total_size = 1;
+ for (int i = 0; i < sizes.size(); ++i) {
+ total_size *= sizes[i];
+ }
+ return total_size;
+}
+}
+
namespace internal {
template <typename DenseIndex, std::size_t NumDims> struct array_size<const DSizes<DenseIndex, NumDims> > {