aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/test/cxx11_tensor_shuffling.cpp
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-07-29 15:01:21 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-07-29 15:01:21 -0700
commite1d28b7ea7ee2aad4603121d5e1bec0c4484c838 (patch)
treeefb7046e83074e76eae2fac92c53832251857295 /unsupported/test/cxx11_tensor_shuffling.cpp
parent0570594f2c0c9fd241ff76f741d034e1daf106f9 (diff)
Added a test for shuffling
Diffstat (limited to 'unsupported/test/cxx11_tensor_shuffling.cpp')
-rw-r--r--unsupported/test/cxx11_tensor_shuffling.cpp53
1 files changed, 47 insertions, 6 deletions
diff --git a/unsupported/test/cxx11_tensor_shuffling.cpp b/unsupported/test/cxx11_tensor_shuffling.cpp
index 2f7fd9e50..d11444a14 100644
--- a/unsupported/test/cxx11_tensor_shuffling.cpp
+++ b/unsupported/test/cxx11_tensor_shuffling.cpp
@@ -176,12 +176,53 @@ static void test_shuffling_as_value()
}
}
+
+template <int DataLayout>
+static void test_shuffle_unshuffle()
+{
+ Tensor<float, 4, DataLayout> tensor(2,3,5,7);
+ tensor.setRandom();
+
+ // Choose a random permutation.
+ array<ptrdiff_t, 4> shuffles;
+ for (int i = 0; i < 4; ++i) {
+ shuffles[i] = i;
+ }
+ array<ptrdiff_t, 4> shuffles_inverse;
+ for (int i = 0; i < 4; ++i) {
+ const ptrdiff_t index = internal::random<ptrdiff_t>(i, 3);
+ shuffles_inverse[shuffles[index]] = i;
+ std::swap(shuffles[i], shuffles[index]);
+ }
+
+ Tensor<float, 4, DataLayout> shuffle;
+ shuffle = tensor.shuffle(shuffles).shuffle(shuffles_inverse);
+
+ VERIFY_IS_EQUAL(shuffle.dimension(0), 2);
+ VERIFY_IS_EQUAL(shuffle.dimension(1), 3);
+ VERIFY_IS_EQUAL(shuffle.dimension(2), 5);
+ VERIFY_IS_EQUAL(shuffle.dimension(3), 7);
+
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ for (int k = 0; k < 5; ++k) {
+ for (int l = 0; l < 7; ++l) {
+ VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(i,j,k,l));
+ }
+ }
+ }
+ }
+}
+
+
void test_cxx11_tensor_shuffling()
{
- CALL_SUBTEST(test_simple_shuffling<ColMajor>());
- CALL_SUBTEST(test_simple_shuffling<RowMajor>());
- CALL_SUBTEST(test_expr_shuffling<ColMajor>());
- CALL_SUBTEST(test_expr_shuffling<RowMajor>());
- CALL_SUBTEST(test_shuffling_as_value<ColMajor>());
- CALL_SUBTEST(test_shuffling_as_value<RowMajor>());
+ CALL_SUBTEST(test_simple_shuffling<ColMajor>());
+ CALL_SUBTEST(test_simple_shuffling<RowMajor>());
+ CALL_SUBTEST(test_expr_shuffling<ColMajor>());
+ CALL_SUBTEST(test_expr_shuffling<RowMajor>());
+ CALL_SUBTEST(test_shuffling_as_value<ColMajor>());
+ CALL_SUBTEST(test_shuffling_as_value<RowMajor>());
+ CALL_SUBTEST(test_shuffle_unshuffle<ColMajor>());
+ CALL_SUBTEST(test_shuffle_unshuffle<RowMajor>());
}