diff options
author | 2015-10-22 21:47:47 +0200 | |
---|---|---|
committer | 2015-10-22 21:47:47 +0200 | |
commit | ebc1af1683a31f75a71e15fef0703d552cf9512b (patch) | |
tree | b0be661effcf211d172cd3b8b297fd207a54cf96 | |
parent | 0eb46508e2e653218557e0ba9858926be9da04e9 (diff) | |
parent | 825146c8fd21cc31e96cef42a57b0bbf25f266de (diff) |
merge
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorMap.h | 8 | ||||
-rw-r--r-- | unsupported/test/cxx11_tensor_map.cpp | 104 |
2 files changed, 110 insertions, 2 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h index 2cb2bc7a6..55c289810 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h @@ -82,15 +82,19 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor } #endif - inline TensorMap(PointerArgType dataPtr, const array<Index, NumIndices>& dimensions) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const array<Index, NumIndices>& dimensions) : m_data(dataPtr), m_dimensions(dimensions) { } template <typename Dimensions> - EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const Dimensions& dimensions) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const Dimensions& dimensions) : m_data(dataPtr), m_dimensions(dimensions) { } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PlainObjectType& tensor) + : m_data(tensor.data()), m_dimensions(tensor.dimensions()) + { } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rank() const { return m_dimensions.rank(); } EIGEN_DEVICE_FUNC diff --git a/unsupported/test/cxx11_tensor_map.cpp b/unsupported/test/cxx11_tensor_map.cpp index 9cf2eb150..4c4e10df2 100644 --- a/unsupported/test/cxx11_tensor_map.cpp +++ b/unsupported/test/cxx11_tensor_map.cpp @@ -139,9 +139,113 @@ static void test_3d() } +static void test_from_tensor() +{ + Tensor<int, 3> mat1(2,3,7); + Tensor<int, 3, RowMajor> mat2(2,3,7); + + int val = 0; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 7; ++k) { + mat1(i,j,k) = val; + mat2(i,j,k) = val; + val++; + } + } + } + + TensorMap<Tensor<int, 3>> mat3(mat1); + TensorMap<Tensor<int, 3, RowMajor>> mat4(mat2); + + VERIFY_IS_EQUAL(mat3.rank(), 3); + VERIFY_IS_EQUAL(mat3.size(), 2*3*7); + VERIFY_IS_EQUAL(mat3.dimension(0), 2); + VERIFY_IS_EQUAL(mat3.dimension(1), 3); + VERIFY_IS_EQUAL(mat3.dimension(2), 7); + + VERIFY_IS_EQUAL(mat4.rank(), 3); + VERIFY_IS_EQUAL(mat4.size(), 2*3*7); + VERIFY_IS_EQUAL(mat4.dimension(0), 2); + VERIFY_IS_EQUAL(mat4.dimension(1), 3); + VERIFY_IS_EQUAL(mat4.dimension(2), 7); + + val = 0; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 7; ++k) { + VERIFY_IS_EQUAL(mat3(i,j,k), val); + VERIFY_IS_EQUAL(mat4(i,j,k), val); + val++; + } + } + } + + TensorFixedSize<int, Sizes<2,3,7>> mat5; + + val = 0; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 7; ++k) { + mat5(i,j,k) = val; + val++; + } + } + } + + TensorMap<TensorFixedSize<int, Sizes<2,3,7>>> mat6(mat5); + + VERIFY_IS_EQUAL(mat6.rank(), 3); + VERIFY_IS_EQUAL(mat6.size(), 2*3*7); + VERIFY_IS_EQUAL(mat6.dimension(0), 2); + VERIFY_IS_EQUAL(mat6.dimension(1), 3); + VERIFY_IS_EQUAL(mat6.dimension(2), 7); + + val = 0; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 7; ++k) { + VERIFY_IS_EQUAL(mat6(i,j,k), val); + val++; + } + } + } +} + + +static int f(const TensorMap<Tensor<int, 3> >& tensor) { + Tensor<int, 1> result = tensor.sum(); + return result(0); +} + +static void test_casting() +{ + Tensor<int, 3> tensor(2,3,7); + + int val = 0; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 7; ++k) { + tensor(i,j,k) = val; + val++; + } + } + } + + TensorMap<Tensor<int, 3>> map(tensor); + int sum1 = f(map); + int sum2 = f(tensor); + + VERIFY_IS_EQUAL(sum1, sum2); + VERIFY_IS_EQUAL(sum1, 861); +} + void test_cxx11_tensor_map() { CALL_SUBTEST(test_1d()); CALL_SUBTEST(test_2d()); CALL_SUBTEST(test_3d()); + + CALL_SUBTEST(test_from_tensor()); + CALL_SUBTEST(test_casting()); } |