diff options
Diffstat (limited to 'unsupported/test/cxx11_tensor_thread_pool.cpp')
-rw-r--r-- | unsupported/test/cxx11_tensor_thread_pool.cpp | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/unsupported/test/cxx11_tensor_thread_pool.cpp b/unsupported/test/cxx11_tensor_thread_pool.cpp index aa76009b7..e46197464 100644 --- a/unsupported/test/cxx11_tensor_thread_pool.cpp +++ b/unsupported/test/cxx11_tensor_thread_pool.cpp @@ -283,6 +283,31 @@ void test_multithread_random() t.device(device) = t.random<Eigen::internal::NormalRandomGenerator<float>>(); } +template<int DataLayout> +void test_multithread_shuffle() +{ + Tensor<float, 4, DataLayout> tensor(17,5,7,11); + tensor.setRandom(); + + const int num_threads = internal::random<int>(2, 11); + ThreadPool threads(num_threads); + Eigen::ThreadPoolDevice device(&threads, num_threads); + + Tensor<float, 4, DataLayout> shuffle(7,5,11,17); + array<ptrdiff_t, 4> shuffles = {{2,1,3,0}}; + shuffle.device(device) = tensor.shuffle(shuffles); + + for (int i = 0; i < 17; ++i) { + for (int j = 0; j < 5; ++j) { + for (int k = 0; k < 7; ++k) { + for (int l = 0; l < 11; ++l) { + VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(k,j,l,i)); + } + } + } + } +} + void test_cxx11_tensor_thread_pool() { @@ -304,4 +329,6 @@ void test_cxx11_tensor_thread_pool() CALL_SUBTEST_6(test_memcpy()); CALL_SUBTEST_6(test_multithread_random()); + CALL_SUBTEST_6(test_multithread_shuffle<ColMajor>()); + CALL_SUBTEST_6(test_multithread_shuffle<RowMajor>()); } |