From d4afccde5a9553ddfb48b0f5fad0115cd8bf791a Mon Sep 17 00:00:00 2001 From: Paul Tucker Date: Thu, 19 Jul 2018 17:43:44 -0700 Subject: Add test coverage for ThreadPoolDevice optional allocator. --- unsupported/test/cxx11_tensor_thread_pool.cpp | 45 ++++++++++++++++++++++++--- 1 file changed, 41 insertions(+), 4 deletions(-) (limited to 'unsupported/test/cxx11_tensor_thread_pool.cpp') diff --git a/unsupported/test/cxx11_tensor_thread_pool.cpp b/unsupported/test/cxx11_tensor_thread_pool.cpp index 2ef665f30..200664740 100644 --- a/unsupported/test/cxx11_tensor_thread_pool.cpp +++ b/unsupported/test/cxx11_tensor_thread_pool.cpp @@ -16,6 +16,25 @@ using Eigen::Tensor; +class TestAllocator : public Allocator { + public: + ~TestAllocator() override {} + EIGEN_DEVICE_FUNC void* allocate(size_t num_bytes) const override { + const_cast(this)->alloc_count_++; + return internal::aligned_malloc(num_bytes); + } + EIGEN_DEVICE_FUNC void deallocate(void* buffer) const override { + const_cast(this)->dealloc_count_++; + internal::aligned_free(buffer); + } + + int alloc_count() const { return alloc_count_; } + int dealloc_count() const { return dealloc_count_; } + + private: + int alloc_count_ = 0; + int dealloc_count_ = 0; +}; void test_multithread_elementwise() { @@ -320,14 +339,14 @@ void test_multithread_random() } template -void test_multithread_shuffle() +void test_multithread_shuffle(Allocator* allocator) { Tensor tensor(17,5,7,11); tensor.setRandom(); const int num_threads = internal::random(2, 11); ThreadPool threads(num_threads); - Eigen::ThreadPoolDevice device(&threads, num_threads); + Eigen::ThreadPoolDevice device(&threads, num_threads, allocator); Tensor shuffle(7,5,11,17); array shuffles = {{2,1,3,0}}; @@ -344,6 +363,21 @@ void test_multithread_shuffle() } } +void test_threadpool_allocate(TestAllocator* allocator) +{ + const int num_threads = internal::random(2, 11); + const int num_allocs = internal::random(2, 11); + ThreadPool threads(num_threads); + Eigen::ThreadPoolDevice device(&threads, num_threads, allocator); + + for (int a = 0; a < num_allocs; ++a) { + void* ptr = device.allocate(512); + device.deallocate(ptr); + } + VERIFY(allocator != nullptr); + VERIFY_IS_EQUAL(allocator->alloc_count(), num_allocs); + VERIFY_IS_EQUAL(allocator->dealloc_count(), num_allocs); +} void test_cxx11_tensor_thread_pool() { @@ -368,6 +402,9 @@ void test_cxx11_tensor_thread_pool() CALL_SUBTEST_6(test_memcpy()); CALL_SUBTEST_6(test_multithread_random()); - CALL_SUBTEST_6(test_multithread_shuffle()); - CALL_SUBTEST_6(test_multithread_shuffle()); + + TestAllocator test_allocator; + CALL_SUBTEST_6(test_multithread_shuffle(nullptr)); + CALL_SUBTEST_6(test_multithread_shuffle(&test_allocator)); + CALL_SUBTEST_6(test_threadpool_allocate(&test_allocator)); } -- cgit v1.2.3