diff options
Diffstat (limited to 'unsupported/test')
-rw-r--r-- | unsupported/test/cxx11_tensor_shuffling.cpp | 53 |
1 files changed, 47 insertions, 6 deletions
diff --git a/unsupported/test/cxx11_tensor_shuffling.cpp b/unsupported/test/cxx11_tensor_shuffling.cpp index 2f7fd9e50..d11444a14 100644 --- a/unsupported/test/cxx11_tensor_shuffling.cpp +++ b/unsupported/test/cxx11_tensor_shuffling.cpp @@ -176,12 +176,53 @@ static void test_shuffling_as_value() } } + +template <int DataLayout> +static void test_shuffle_unshuffle() +{ + Tensor<float, 4, DataLayout> tensor(2,3,5,7); + tensor.setRandom(); + + // Choose a random permutation. + array<ptrdiff_t, 4> shuffles; + for (int i = 0; i < 4; ++i) { + shuffles[i] = i; + } + array<ptrdiff_t, 4> shuffles_inverse; + for (int i = 0; i < 4; ++i) { + const ptrdiff_t index = internal::random<ptrdiff_t>(i, 3); + shuffles_inverse[shuffles[index]] = i; + std::swap(shuffles[i], shuffles[index]); + } + + Tensor<float, 4, DataLayout> shuffle; + shuffle = tensor.shuffle(shuffles).shuffle(shuffles_inverse); + + VERIFY_IS_EQUAL(shuffle.dimension(0), 2); + VERIFY_IS_EQUAL(shuffle.dimension(1), 3); + VERIFY_IS_EQUAL(shuffle.dimension(2), 5); + VERIFY_IS_EQUAL(shuffle.dimension(3), 7); + + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 5; ++k) { + for (int l = 0; l < 7; ++l) { + VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(i,j,k,l)); + } + } + } + } +} + + void test_cxx11_tensor_shuffling() { - CALL_SUBTEST(test_simple_shuffling<ColMajor>()); - CALL_SUBTEST(test_simple_shuffling<RowMajor>()); - CALL_SUBTEST(test_expr_shuffling<ColMajor>()); - CALL_SUBTEST(test_expr_shuffling<RowMajor>()); - CALL_SUBTEST(test_shuffling_as_value<ColMajor>()); - CALL_SUBTEST(test_shuffling_as_value<RowMajor>()); + CALL_SUBTEST(test_simple_shuffling<ColMajor>()); + CALL_SUBTEST(test_simple_shuffling<RowMajor>()); + CALL_SUBTEST(test_expr_shuffling<ColMajor>()); + CALL_SUBTEST(test_expr_shuffling<RowMajor>()); + CALL_SUBTEST(test_shuffling_as_value<ColMajor>()); + CALL_SUBTEST(test_shuffling_as_value<RowMajor>()); + CALL_SUBTEST(test_shuffle_unshuffle<ColMajor>()); + CALL_SUBTEST(test_shuffle_unshuffle<RowMajor>()); } |