diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-06-30 13:08:12 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-06-30 13:08:12 -0700 |
commit | 109005c6c9a3e6e1e42ff9efc71be3a899519ed4 (patch) | |
tree | 7f04bb79861158191692a223f282dffe73e46f03 /unsupported/test | |
parent | a4aa7c62177ec333b91e186b10abff3bbb573077 (diff) |
Added a test for multithreaded full reductions
Diffstat (limited to 'unsupported/test')
-rw-r--r-- | unsupported/test/cxx11_tensor_thread_pool.cpp | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/unsupported/test/cxx11_tensor_thread_pool.cpp b/unsupported/test/cxx11_tensor_thread_pool.cpp index 0a20a01a4..5ec7c8bf4 100644 --- a/unsupported/test/cxx11_tensor_thread_pool.cpp +++ b/unsupported/test/cxx11_tensor_thread_pool.cpp @@ -228,6 +228,29 @@ static void test_multithread_contraction_agrees_with_singlethread() { } +template<int DataLayout> +static void test_multithreaded_reductions() { + const int num_threads = internal::random<int>(3, 11); + ThreadPool thread_pool(num_threads); + Eigen::ThreadPoolDevice thread_pool_device(&thread_pool, num_threads); + + const int num_rows = internal::random<int>(13, 732); + const int num_cols = internal::random<int>(13, 732); + Tensor<float, 2, DataLayout> t1(num_rows, num_cols); + t1.setRandom(); + + Tensor<float, 1, DataLayout> full_redux(1); + full_redux = t1.sum(); + + Tensor<float, 1, DataLayout> full_redux_tp(1); + full_redux_tp.device(thread_pool_device) = t1.sum(); + + // Check that the single threaded and the multi threaded reductions return + // the same result. + VERIFY_IS_APPROX(full_redux(0), full_redux_tp(0)); +} + + static void test_memcpy() { for (int i = 0; i < 5; ++i) { @@ -271,6 +294,9 @@ void test_cxx11_tensor_thread_pool() CALL_SUBTEST(test_contraction_corner_cases<ColMajor>()); CALL_SUBTEST(test_contraction_corner_cases<RowMajor>()); + CALL_SUBTEST(test_multithreaded_reductions<ColMajor>()); + CALL_SUBTEST(test_multithreaded_reductions<RowMajor>()); + CALL_SUBTEST(test_memcpy()); CALL_SUBTEST(test_multithread_random()); |