aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/test/cxx11_tensor_thread_pool.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/test/cxx11_tensor_thread_pool.cpp')
-rw-r--r--unsupported/test/cxx11_tensor_thread_pool.cpp85
1 files changed, 58 insertions, 27 deletions
diff --git a/unsupported/test/cxx11_tensor_thread_pool.cpp b/unsupported/test/cxx11_tensor_thread_pool.cpp
index e28cf55e2..e46197464 100644
--- a/unsupported/test/cxx11_tensor_thread_pool.cpp
+++ b/unsupported/test/cxx11_tensor_thread_pool.cpp
@@ -17,7 +17,7 @@
using Eigen::Tensor;
-static void test_multithread_elementwise()
+void test_multithread_elementwise()
{
Tensor<float, 3> in1(2,3,7);
Tensor<float, 3> in2(2,3,7);
@@ -40,7 +40,7 @@ static void test_multithread_elementwise()
}
-static void test_multithread_compound_assignment()
+void test_multithread_compound_assignment()
{
Tensor<float, 3> in1(2,3,7);
Tensor<float, 3> in2(2,3,7);
@@ -64,7 +64,7 @@ static void test_multithread_compound_assignment()
}
template<int DataLayout>
-static void test_multithread_contraction()
+void test_multithread_contraction()
{
Tensor<float, 4, DataLayout> t_left(30, 50, 37, 31);
Tensor<float, 5, DataLayout> t_right(37, 31, 70, 2, 10);
@@ -91,15 +91,20 @@ static void test_multithread_contraction()
for (ptrdiff_t i = 0; i < t_result.size(); i++) {
VERIFY(&t_result.data()[i] != &m_result.data()[i]);
- if (fabs(t_result.data()[i] - m_result.data()[i]) >= 1e-4) {
- std::cout << "mismatch detected: " << t_result.data()[i] << " vs " << m_result.data()[i] << std::endl;
- assert(false);
+ if (fabs(t_result(i) - m_result(i)) < 1e-4) {
+ continue;
+ }
+ if (Eigen::internal::isApprox(t_result(i), m_result(i), 1e-4f)) {
+ continue;
}
+ std::cout << "mismatch detected at index " << i << ": " << t_result(i)
+ << " vs " << m_result(i) << std::endl;
+ assert(false);
}
}
template<int DataLayout>
-static void test_contraction_corner_cases()
+void test_contraction_corner_cases()
{
Tensor<float, 2, DataLayout> t_left(32, 500);
Tensor<float, 2, DataLayout> t_right(32, 28*28);
@@ -186,7 +191,7 @@ static void test_contraction_corner_cases()
}
template<int DataLayout>
-static void test_multithread_contraction_agrees_with_singlethread() {
+void test_multithread_contraction_agrees_with_singlethread() {
int contract_size = internal::random<int>(1, 5000);
Tensor<float, 3, DataLayout> left(internal::random<int>(1, 80),
@@ -229,7 +234,7 @@ static void test_multithread_contraction_agrees_with_singlethread() {
template<int DataLayout>
-static void test_multithreaded_reductions() {
+void test_multithreaded_reductions() {
const int num_threads = internal::random<int>(3, 11);
ThreadPool thread_pool(num_threads);
Eigen::ThreadPoolDevice thread_pool_device(&thread_pool, num_threads);
@@ -239,19 +244,19 @@ static void test_multithreaded_reductions() {
Tensor<float, 2, DataLayout> t1(num_rows, num_cols);
t1.setRandom();
- Tensor<float, 1, DataLayout> full_redux(1);
+ Tensor<float, 0, DataLayout> full_redux;
full_redux = t1.sum();
- Tensor<float, 1, DataLayout> full_redux_tp(1);
+ Tensor<float, 0, DataLayout> full_redux_tp;
full_redux_tp.device(thread_pool_device) = t1.sum();
// Check that the single threaded and the multi threaded reductions return
// the same result.
- VERIFY_IS_APPROX(full_redux(0), full_redux_tp(0));
+ VERIFY_IS_APPROX(full_redux(), full_redux_tp());
}
-static void test_memcpy() {
+void test_memcpy() {
for (int i = 0; i < 5; ++i) {
const int num_threads = internal::random<int>(3, 11);
@@ -270,7 +275,7 @@ static void test_memcpy() {
}
-static void test_multithread_random()
+void test_multithread_random()
{
Eigen::ThreadPool tp(2);
Eigen::ThreadPoolDevice device(&tp, 2);
@@ -278,26 +283,52 @@ static void test_multithread_random()
t.device(device) = t.random<Eigen::internal::NormalRandomGenerator<float>>();
}
+template<int DataLayout>
+void test_multithread_shuffle()
+{
+ Tensor<float, 4, DataLayout> tensor(17,5,7,11);
+ tensor.setRandom();
+
+ const int num_threads = internal::random<int>(2, 11);
+ ThreadPool threads(num_threads);
+ Eigen::ThreadPoolDevice device(&threads, num_threads);
+
+ Tensor<float, 4, DataLayout> shuffle(7,5,11,17);
+ array<ptrdiff_t, 4> shuffles = {{2,1,3,0}};
+ shuffle.device(device) = tensor.shuffle(shuffles);
+
+ for (int i = 0; i < 17; ++i) {
+ for (int j = 0; j < 5; ++j) {
+ for (int k = 0; k < 7; ++k) {
+ for (int l = 0; l < 11; ++l) {
+ VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(k,j,l,i));
+ }
+ }
+ }
+ }
+}
+
void test_cxx11_tensor_thread_pool()
{
- CALL_SUBTEST(test_multithread_elementwise());
- CALL_SUBTEST(test_multithread_compound_assignment());
+ CALL_SUBTEST_1(test_multithread_elementwise());
+ CALL_SUBTEST_1(test_multithread_compound_assignment());
- CALL_SUBTEST(test_multithread_contraction<ColMajor>());
- CALL_SUBTEST(test_multithread_contraction<RowMajor>());
+ CALL_SUBTEST_2(test_multithread_contraction<ColMajor>());
+ CALL_SUBTEST_2(test_multithread_contraction<RowMajor>());
- CALL_SUBTEST(test_multithread_contraction_agrees_with_singlethread<ColMajor>());
- CALL_SUBTEST(test_multithread_contraction_agrees_with_singlethread<RowMajor>());
+ CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread<ColMajor>());
+ CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread<RowMajor>());
// Exercise various cases that have been problematic in the past.
- CALL_SUBTEST(test_contraction_corner_cases<ColMajor>());
- CALL_SUBTEST(test_contraction_corner_cases<RowMajor>());
-
- CALL_SUBTEST(test_multithreaded_reductions<ColMajor>());
- CALL_SUBTEST(test_multithreaded_reductions<RowMajor>());
+ CALL_SUBTEST_4(test_contraction_corner_cases<ColMajor>());
+ CALL_SUBTEST_4(test_contraction_corner_cases<RowMajor>());
- CALL_SUBTEST(test_memcpy());
+ CALL_SUBTEST_5(test_multithreaded_reductions<ColMajor>());
+ CALL_SUBTEST_5(test_multithreaded_reductions<RowMajor>());
- CALL_SUBTEST(test_multithread_random());
+ CALL_SUBTEST_6(test_memcpy());
+ CALL_SUBTEST_6(test_multithread_random());
+ CALL_SUBTEST_6(test_multithread_shuffle<ColMajor>());
+ CALL_SUBTEST_6(test_multithread_shuffle<RowMajor>());
}