From 4cf7da63de0987dc8b49e5801f0cb79eb7fa6dbb Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Thu, 22 Oct 2015 11:48:02 -0700 Subject: Added a constructor to simplify the construction of tensormap from tensor --- unsupported/test/cxx11_tensor_map.cpp | 104 ++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) (limited to 'unsupported/test/cxx11_tensor_map.cpp') diff --git a/unsupported/test/cxx11_tensor_map.cpp b/unsupported/test/cxx11_tensor_map.cpp index 9cf2eb150..9ef935853 100644 --- a/unsupported/test/cxx11_tensor_map.cpp +++ b/unsupported/test/cxx11_tensor_map.cpp @@ -139,9 +139,113 @@ static void test_3d() } +static void test_from_tensor() +{ + Tensor mat1(2,3,7); + Tensor mat2(2,3,7); + + int val = 0; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 7; ++k) { + mat1(i,j,k) = val; + mat2(i,j,k) = val; + val++; + } + } + } + + TensorMap> mat3(mat1); + TensorMap> mat4(mat2); + + VERIFY_IS_EQUAL(mat3.rank(), 3); + VERIFY_IS_EQUAL(mat3.size(), 2*3*7); + VERIFY_IS_EQUAL(mat3.dimension(0), 2); + VERIFY_IS_EQUAL(mat3.dimension(1), 3); + VERIFY_IS_EQUAL(mat3.dimension(2), 7); + + VERIFY_IS_EQUAL(mat4.rank(), 3); + VERIFY_IS_EQUAL(mat4.size(), 2*3*7); + VERIFY_IS_EQUAL(mat4.dimension(0), 2); + VERIFY_IS_EQUAL(mat4.dimension(1), 3); + VERIFY_IS_EQUAL(mat4.dimension(2), 7); + + val = 0; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 7; ++k) { + VERIFY_IS_EQUAL(mat3(i,j,k), val); + VERIFY_IS_EQUAL(mat4(i,j,k), val); + val++; + } + } + } + + TensorFixedSize> mat5; + + val = 0; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 7; ++k) { + mat5(i,j,k) = val; + val++; + } + } + } + + TensorMap>> mat6(mat5); + + VERIFY_IS_EQUAL(mat6.rank(), 3); + VERIFY_IS_EQUAL(mat6.size(), 2*3*7); + VERIFY_IS_EQUAL(mat6.dimension(0), 2); + VERIFY_IS_EQUAL(mat6.dimension(1), 3); + VERIFY_IS_EQUAL(mat6.dimension(2), 7); + + val = 0; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 7; ++k) { + VERIFY_IS_EQUAL(mat6(i,j,k), val); + val++; + } + } + } +} + + +static int f(const TensorMap >& tensor) { + Tensor result = tensor.sum(); + return result(0); +} + +static void test_casting() +{ + Tensor tensor(2,3,7); + + int val = 0; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 7; ++k) { + tensor(i,j,k) = val; + val++; + } + } + } + + TensorMap> map(tensor); + int sum1 = f(map); + int sum2 = f(tensor); + + VERIFY_IS_EQUAL(sum1, sum2); + VERIFY_IS_EQUAL(sum1, 41); +} + void test_cxx11_tensor_map() { CALL_SUBTEST(test_1d()); CALL_SUBTEST(test_2d()); CALL_SUBTEST(test_3d()); + + CALL_SUBTEST(test_from_tensor()); + CALL_SUBTEST(test_casting()); } -- cgit v1.2.3