From f363e533aac5aac0d67fd5728b2e5b509c756bc8 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Thu, 5 May 2016 09:05:45 -0700 Subject: Added tests for full contractions using thread pools and gpu devices. Fixed a couple of issues in the corresponding code. --- unsupported/test/cxx11_tensor_thread_pool.cpp | 39 +++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) (limited to 'unsupported/test/cxx11_tensor_thread_pool.cpp') diff --git a/unsupported/test/cxx11_tensor_thread_pool.cpp b/unsupported/test/cxx11_tensor_thread_pool.cpp index e46197464..5fd3f0bf1 100644 --- a/unsupported/test/cxx11_tensor_thread_pool.cpp +++ b/unsupported/test/cxx11_tensor_thread_pool.cpp @@ -233,6 +233,42 @@ void test_multithread_contraction_agrees_with_singlethread() { } +template +void test_full_contraction() { + int contract_size1 = internal::random(1, 500); + int contract_size2 = internal::random(1, 500); + + Tensor left(contract_size1, + contract_size2); + Tensor right(contract_size1, + contract_size2); + left.setRandom(); + right.setRandom(); + + // add constants to shift values away from 0 for more precision + left += left.constant(1.5f); + right += right.constant(1.5f); + + typedef Tensor::DimensionPair DimPair; + Eigen::array dims({{DimPair(0, 0), DimPair(1, 1)}}); + + Eigen::ThreadPool tp(internal::random(2, 11)); + Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random(2, 11)); + + Tensor st_result; + st_result = left.contract(right, dims); + + Tensor tp_result; + tp_result.device(thread_pool_device) = left.contract(right, dims); + + VERIFY(dimensions_match(st_result.dimensions(), tp_result.dimensions())); + // if both of the values are very small, then do nothing (because the test will fail + // due to numerical precision issues when values are small) + if (fabs(st_result() - tp_result()) >= 1e-4) { + VERIFY_IS_APPROX(st_result(), tp_result()); + } +} + template void test_multithreaded_reductions() { const int num_threads = internal::random(3, 11); @@ -324,6 +360,9 @@ void test_cxx11_tensor_thread_pool() CALL_SUBTEST_4(test_contraction_corner_cases()); CALL_SUBTEST_4(test_contraction_corner_cases()); + CALL_SUBTEST_4(test_full_contraction()); + CALL_SUBTEST_4(test_full_contraction()); + CALL_SUBTEST_5(test_multithreaded_reductions()); CALL_SUBTEST_5(test_multithreaded_reductions()); -- cgit v1.2.3