aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/test/cxx11_tensor_thread_pool.cpp
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-01-30 10:56:47 -0800
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-01-30 10:56:47 -0800
commit9de155d15320a68182e7f572adf504cad6172419 (patch)
tree9d2507c04ce4b939c1e36037b17ed5a5140779fd /unsupported/test/cxx11_tensor_thread_pool.cpp
parent32088c06a169ed8d1286c491ed21a20321ae58a5 (diff)
Added a test to cover threaded tensor shuffling
Diffstat (limited to 'unsupported/test/cxx11_tensor_thread_pool.cpp')
-rw-r--r--unsupported/test/cxx11_tensor_thread_pool.cpp27
1 files changed, 27 insertions, 0 deletions
diff --git a/unsupported/test/cxx11_tensor_thread_pool.cpp b/unsupported/test/cxx11_tensor_thread_pool.cpp
index aa76009b7..e46197464 100644
--- a/unsupported/test/cxx11_tensor_thread_pool.cpp
+++ b/unsupported/test/cxx11_tensor_thread_pool.cpp
@@ -283,6 +283,31 @@ void test_multithread_random()
t.device(device) = t.random<Eigen::internal::NormalRandomGenerator<float>>();
}
+template<int DataLayout>
+void test_multithread_shuffle()
+{
+ Tensor<float, 4, DataLayout> tensor(17,5,7,11);
+ tensor.setRandom();
+
+ const int num_threads = internal::random<int>(2, 11);
+ ThreadPool threads(num_threads);
+ Eigen::ThreadPoolDevice device(&threads, num_threads);
+
+ Tensor<float, 4, DataLayout> shuffle(7,5,11,17);
+ array<ptrdiff_t, 4> shuffles = {{2,1,3,0}};
+ shuffle.device(device) = tensor.shuffle(shuffles);
+
+ for (int i = 0; i < 17; ++i) {
+ for (int j = 0; j < 5; ++j) {
+ for (int k = 0; k < 7; ++k) {
+ for (int l = 0; l < 11; ++l) {
+ VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(k,j,l,i));
+ }
+ }
+ }
+ }
+}
+
void test_cxx11_tensor_thread_pool()
{
@@ -304,4 +329,6 @@ void test_cxx11_tensor_thread_pool()
CALL_SUBTEST_6(test_memcpy());
CALL_SUBTEST_6(test_multithread_random());
+ CALL_SUBTEST_6(test_multithread_shuffle<ColMajor>());
+ CALL_SUBTEST_6(test_multithread_shuffle<RowMajor>());
}