aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-10-15 14:58:49 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-10-15 14:58:49 -0700
commitde1e9f29f4db2c837ffb354c90f9e9fb7df05e85 (patch)
tree99832d8d52f1b46063a82c1c9133e9b598df2d1b /unsupported
parent6585efc55354b38c65de8c23599e99f3caaca843 (diff)
Updated the custom indexing code: we can now use any container that provides the [] operator to index a tensor. Added unit tests to validate the use of std::map and a few more types as valid custom index containers
Diffstat (limited to 'unsupported')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/Tensor.h17
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h8
-rw-r--r--unsupported/test/cxx11_tensor_custom_index.cpp70
-rw-r--r--unsupported/test/cxx11_tensor_simple.cpp4
4 files changed, 77 insertions, 22 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/Tensor.h b/unsupported/Eigen/CXX11/src/Tensor/Tensor.h
index 57d44baf9..3ac465d24 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/Tensor.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/Tensor.h
@@ -91,7 +91,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
#ifdef EIGEN_HAS_SFINAE
template<typename CustomIndices>
struct isOfNormalIndex{
- static const bool is_array = internal::is_base_of<array<Index, NumIndices>, CustomIndices >::value;
+ static const bool is_array = internal::is_base_of<array<Index, NumIndices>, CustomIndices>::value;
static const bool is_int = NumTraits<CustomIndices>::IsInteger;
static const bool value = is_array | is_int;
};
@@ -120,11 +120,8 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
EIGEN_STATIC_ASSERT(sizeof...(otherIndices) + 2 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
return coeff(array<Index, NumIndices>{{firstIndex, secondIndex, otherIndices...}});
}
-
-
#endif
-
// normal indices
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(const array<Index, NumIndices>& indices) const
{
@@ -137,7 +134,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
template<typename CustomIndices,
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) )
>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(const CustomIndices & indices) const
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(CustomIndices& indices) const
{
return coeff(internal::customIndices2Array<Index,NumIndices>(indices));
}
@@ -171,7 +168,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
template<typename CustomIndices,
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) )
>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(const CustomIndices & indices)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(CustomIndices& indices)
{
return coeffRef(internal::customIndices2Array<Index,NumIndices>(indices));
}
@@ -219,7 +216,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
template<typename CustomIndices,
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) )
>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()(const CustomIndices & indices) const
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()(CustomIndices& indices) const
{
return coeff(internal::customIndices2Array<Index,NumIndices>(indices));
}
@@ -286,7 +283,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
template<typename CustomIndices,
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) )
>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(const CustomIndices & indices)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(CustomIndices& indices)
{
return coeffRef(internal::customIndices2Array<Index,NumIndices>(indices));
}
@@ -441,9 +438,9 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
template<typename CustomDimension,
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomDimension>::value) )
>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void resize(const CustomDimension & dimensions)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void resize(CustomDimension& dimensions)
{
- return coeffRef(internal::customIndices2Array<Index,NumIndices>(dimensions));
+ resize(internal::customIndices2Array<Index,NumIndices>(dimensions));
}
#endif
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h
index d1efc1a87..07735fa5f 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h
@@ -82,15 +82,15 @@ namespace internal{
template<typename IndexType, Index... Is>
EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- array<Index,sizeof...(Is)> customIndices2Array(const IndexType & idx, numeric_list<Index,Is...>) {
- return { idx(Is)... };
+ array<Index, sizeof...(Is)> customIndices2Array(IndexType& idx, numeric_list<Index, Is...>) {
+ return { idx[Is]... };
}
/** Make an array (for index/dimensions) out of a custom index */
template<typename Index, int NumIndices, typename IndexType>
EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- array<Index,NumIndices> customIndices2Array(const IndexType & idx) {
- return customIndices2Array(idx, typename gen_numeric_list<Index,NumIndices>::type{});
+ array<Index, NumIndices> customIndices2Array(IndexType& idx) {
+ return customIndices2Array(idx, typename gen_numeric_list<Index, NumIndices>::type{});
}
diff --git a/unsupported/test/cxx11_tensor_custom_index.cpp b/unsupported/test/cxx11_tensor_custom_index.cpp
index ff9545a7a..4528cc176 100644
--- a/unsupported/test/cxx11_tensor_custom_index.cpp
+++ b/unsupported/test/cxx11_tensor_custom_index.cpp
@@ -9,6 +9,7 @@
#include "main.h"
#include <limits>
+#include <map>
#include <Eigen/Dense>
#include <Eigen/CXX11/Tensor>
@@ -17,22 +18,83 @@ using Eigen::Tensor;
template <int DataLayout>
-static void test_custom_index() {
+static void test_map_as_index()
+{
+#ifdef EIGEN_HAS_SFINAE
+ Tensor<float, 4, DataLayout> tensor(2, 3, 5, 7);
+ tensor.setRandom();
+
+ using NormalIndex = DSizes<ptrdiff_t, 4>;
+ using CustomIndex = std::map<ptrdiff_t, ptrdiff_t>;
+ CustomIndex coeffC;
+ coeffC[0] = 1;
+ coeffC[1] = 2;
+ coeffC[2] = 4;
+ coeffC[3] = 1;
+ NormalIndex coeff(1,2,4,1);
+ VERIFY_IS_EQUAL(tensor.coeff(coeffC), tensor.coeff(coeff));
+ VERIFY_IS_EQUAL(tensor.coeffRef(coeffC), tensor.coeffRef(coeff));
+#endif
+}
+
+
+template <int DataLayout>
+static void test_matrix_as_index()
+{
+#ifdef EIGEN_HAS_SFINAE
Tensor<float, 4, DataLayout> tensor(2, 3, 5, 7);
tensor.setRandom();
using NormalIndex = DSizes<ptrdiff_t, 4>;
- using CustomIndex = Matrix<unsigned int , 4, 1>;
+ using CustomIndex = Matrix<unsigned int, 4, 1>;
CustomIndex coeffC(1,2,4,1);
NormalIndex coeff(1,2,4,1);
VERIFY_IS_EQUAL(tensor.coeff(coeffC), tensor.coeff(coeff));
VERIFY_IS_EQUAL(tensor.coeffRef(coeffC), tensor.coeffRef(coeff));
+#endif
+}
+
+
+template <int DataLayout>
+static void test_varlist_as_index()
+{
+#ifdef EIGEN_HAS_SFINAE
+ Tensor<float, 4, DataLayout> tensor(2, 3, 5, 7);
+ tensor.setRandom();
+
+ DSizes<ptrdiff_t, 4> coeff(1,2,4,1);
+
+ VERIFY_IS_EQUAL(tensor.coeff({1,2,4,1}), tensor.coeff(coeff));
+ VERIFY_IS_EQUAL(tensor.coeffRef({1,2,4,1}), tensor.coeffRef(coeff));
+#endif
+}
+
+
+template <int DataLayout>
+static void test_sizes_as_index()
+{
+#ifdef EIGEN_HAS_SFINAE
+ Tensor<float, 4, DataLayout> tensor(2, 3, 5, 7);
+ tensor.setRandom();
+
+ DSizes<ptrdiff_t, 4> coeff(1,2,4,1);
+ Sizes<1,2,4,1> coeffC;
+
+ VERIFY_IS_EQUAL(tensor.coeff(coeffC), tensor.coeff(coeff));
+ VERIFY_IS_EQUAL(tensor.coeffRef(coeffC), tensor.coeffRef(coeff));
+#endif
}
void test_cxx11_tensor_custom_index() {
- test_custom_index<ColMajor>();
- test_custom_index<RowMajor>();
+ test_map_as_index<ColMajor>();
+ test_map_as_index<RowMajor>();
+ test_matrix_as_index<ColMajor>();
+ test_matrix_as_index<RowMajor>();
+ test_varlist_as_index<ColMajor>();
+ test_varlist_as_index<RowMajor>();
+ test_sizes_as_index<ColMajor>();
+ test_sizes_as_index<RowMajor>();
}
diff --git a/unsupported/test/cxx11_tensor_simple.cpp b/unsupported/test/cxx11_tensor_simple.cpp
index 8cd2ab7fd..0ce92eed9 100644
--- a/unsupported/test/cxx11_tensor_simple.cpp
+++ b/unsupported/test/cxx11_tensor_simple.cpp
@@ -293,7 +293,3 @@ void test_cxx11_tensor_simple()
CALL_SUBTEST(test_simple_assign());
CALL_SUBTEST(test_resize());
}
-
-/*
- * kate: space-indent on; indent-width 2; mixedindent off; indent-mode cstyle;
- */