From e1d28b7ea7ee2aad4603121d5e1bec0c4484c838 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 29 Jul 2015 15:01:21 -0700 Subject: Added a test for shuffling --- unsupported/test/cxx11_tensor_shuffling.cpp | 53 +++++++++++++++++++++++++---- 1 file changed, 47 insertions(+), 6 deletions(-) (limited to 'unsupported/test/cxx11_tensor_shuffling.cpp') diff --git a/unsupported/test/cxx11_tensor_shuffling.cpp b/unsupported/test/cxx11_tensor_shuffling.cpp index 2f7fd9e50..d11444a14 100644 --- a/unsupported/test/cxx11_tensor_shuffling.cpp +++ b/unsupported/test/cxx11_tensor_shuffling.cpp @@ -176,12 +176,53 @@ static void test_shuffling_as_value() } } + +template +static void test_shuffle_unshuffle() +{ + Tensor tensor(2,3,5,7); + tensor.setRandom(); + + // Choose a random permutation. + array shuffles; + for (int i = 0; i < 4; ++i) { + shuffles[i] = i; + } + array shuffles_inverse; + for (int i = 0; i < 4; ++i) { + const ptrdiff_t index = internal::random(i, 3); + shuffles_inverse[shuffles[index]] = i; + std::swap(shuffles[i], shuffles[index]); + } + + Tensor shuffle; + shuffle = tensor.shuffle(shuffles).shuffle(shuffles_inverse); + + VERIFY_IS_EQUAL(shuffle.dimension(0), 2); + VERIFY_IS_EQUAL(shuffle.dimension(1), 3); + VERIFY_IS_EQUAL(shuffle.dimension(2), 5); + VERIFY_IS_EQUAL(shuffle.dimension(3), 7); + + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 5; ++k) { + for (int l = 0; l < 7; ++l) { + VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(i,j,k,l)); + } + } + } + } +} + + void test_cxx11_tensor_shuffling() { - CALL_SUBTEST(test_simple_shuffling()); - CALL_SUBTEST(test_simple_shuffling()); - CALL_SUBTEST(test_expr_shuffling()); - CALL_SUBTEST(test_expr_shuffling()); - CALL_SUBTEST(test_shuffling_as_value()); - CALL_SUBTEST(test_shuffling_as_value()); + CALL_SUBTEST(test_simple_shuffling()); + CALL_SUBTEST(test_simple_shuffling()); + CALL_SUBTEST(test_expr_shuffling()); + CALL_SUBTEST(test_expr_shuffling()); + CALL_SUBTEST(test_shuffling_as_value()); + CALL_SUBTEST(test_shuffling_as_value()); + CALL_SUBTEST(test_shuffle_unshuffle()); + CALL_SUBTEST(test_shuffle_unshuffle()); } -- cgit v1.2.3