aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/test/cxx11_tensor_thread_pool.cpp
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-06-30 13:08:12 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-06-30 13:08:12 -0700
commit109005c6c9a3e6e1e42ff9efc71be3a899519ed4 (patch)
tree7f04bb79861158191692a223f282dffe73e46f03 /unsupported/test/cxx11_tensor_thread_pool.cpp
parenta4aa7c62177ec333b91e186b10abff3bbb573077 (diff)
Added a test for multithreaded full reductions
Diffstat (limited to 'unsupported/test/cxx11_tensor_thread_pool.cpp')
-rw-r--r--unsupported/test/cxx11_tensor_thread_pool.cpp26
1 files changed, 26 insertions, 0 deletions
diff --git a/unsupported/test/cxx11_tensor_thread_pool.cpp b/unsupported/test/cxx11_tensor_thread_pool.cpp
index 0a20a01a4..5ec7c8bf4 100644
--- a/unsupported/test/cxx11_tensor_thread_pool.cpp
+++ b/unsupported/test/cxx11_tensor_thread_pool.cpp
@@ -228,6 +228,29 @@ static void test_multithread_contraction_agrees_with_singlethread() {
}
+template<int DataLayout>
+static void test_multithreaded_reductions() {
+ const int num_threads = internal::random<int>(3, 11);
+ ThreadPool thread_pool(num_threads);
+ Eigen::ThreadPoolDevice thread_pool_device(&thread_pool, num_threads);
+
+ const int num_rows = internal::random<int>(13, 732);
+ const int num_cols = internal::random<int>(13, 732);
+ Tensor<float, 2, DataLayout> t1(num_rows, num_cols);
+ t1.setRandom();
+
+ Tensor<float, 1, DataLayout> full_redux(1);
+ full_redux = t1.sum();
+
+ Tensor<float, 1, DataLayout> full_redux_tp(1);
+ full_redux_tp.device(thread_pool_device) = t1.sum();
+
+ // Check that the single threaded and the multi threaded reductions return
+ // the same result.
+ VERIFY_IS_APPROX(full_redux(0), full_redux_tp(0));
+}
+
+
static void test_memcpy() {
for (int i = 0; i < 5; ++i) {
@@ -271,6 +294,9 @@ void test_cxx11_tensor_thread_pool()
CALL_SUBTEST(test_contraction_corner_cases<ColMajor>());
CALL_SUBTEST(test_contraction_corner_cases<RowMajor>());
+ CALL_SUBTEST(test_multithreaded_reductions<ColMajor>());
+ CALL_SUBTEST(test_multithreaded_reductions<RowMajor>());
+
CALL_SUBTEST(test_memcpy());
CALL_SUBTEST(test_multithread_random());