diff options
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h | 7 | ||||
-rw-r--r-- | unsupported/test/cxx11_tensor_thread_pool.cpp | 45 |
2 files changed, 48 insertions, 4 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h index c9534d400..f4123b71d 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h @@ -91,6 +91,13 @@ static EIGEN_STRONG_INLINE void wait_until_ready(SyncType* n) { } } +// An abstract interface to a device specific memory allocator. +class Allocator { + public: + virtual ~Allocator() {} + EIGEN_DEVICE_FUNC virtual void* allocate(size_t num_bytes) const = 0; + EIGEN_DEVICE_FUNC virtual void deallocate(void* buffer) const = 0; +}; // Build a thread pool device on top the an existing pool of threads. struct ThreadPoolDevice { diff --git a/unsupported/test/cxx11_tensor_thread_pool.cpp b/unsupported/test/cxx11_tensor_thread_pool.cpp index 2ef665f30..200664740 100644 --- a/unsupported/test/cxx11_tensor_thread_pool.cpp +++ b/unsupported/test/cxx11_tensor_thread_pool.cpp @@ -16,6 +16,25 @@ using Eigen::Tensor; +class TestAllocator : public Allocator { + public: + ~TestAllocator() override {} + EIGEN_DEVICE_FUNC void* allocate(size_t num_bytes) const override { + const_cast<TestAllocator*>(this)->alloc_count_++; + return internal::aligned_malloc(num_bytes); + } + EIGEN_DEVICE_FUNC void deallocate(void* buffer) const override { + const_cast<TestAllocator*>(this)->dealloc_count_++; + internal::aligned_free(buffer); + } + + int alloc_count() const { return alloc_count_; } + int dealloc_count() const { return dealloc_count_; } + + private: + int alloc_count_ = 0; + int dealloc_count_ = 0; +}; void test_multithread_elementwise() { @@ -320,14 +339,14 @@ void test_multithread_random() } template<int DataLayout> -void test_multithread_shuffle() +void test_multithread_shuffle(Allocator* allocator) { Tensor<float, 4, DataLayout> tensor(17,5,7,11); tensor.setRandom(); const int num_threads = internal::random<int>(2, 11); ThreadPool threads(num_threads); - Eigen::ThreadPoolDevice device(&threads, num_threads); + Eigen::ThreadPoolDevice device(&threads, num_threads, allocator); Tensor<float, 4, DataLayout> shuffle(7,5,11,17); array<ptrdiff_t, 4> shuffles = {{2,1,3,0}}; @@ -344,6 +363,21 @@ void test_multithread_shuffle() } } +void test_threadpool_allocate(TestAllocator* allocator) +{ + const int num_threads = internal::random<int>(2, 11); + const int num_allocs = internal::random<int>(2, 11); + ThreadPool threads(num_threads); + Eigen::ThreadPoolDevice device(&threads, num_threads, allocator); + + for (int a = 0; a < num_allocs; ++a) { + void* ptr = device.allocate(512); + device.deallocate(ptr); + } + VERIFY(allocator != nullptr); + VERIFY_IS_EQUAL(allocator->alloc_count(), num_allocs); + VERIFY_IS_EQUAL(allocator->dealloc_count(), num_allocs); +} void test_cxx11_tensor_thread_pool() { @@ -368,6 +402,9 @@ void test_cxx11_tensor_thread_pool() CALL_SUBTEST_6(test_memcpy()); CALL_SUBTEST_6(test_multithread_random()); - CALL_SUBTEST_6(test_multithread_shuffle<ColMajor>()); - CALL_SUBTEST_6(test_multithread_shuffle<RowMajor>()); + + TestAllocator test_allocator; + CALL_SUBTEST_6(test_multithread_shuffle<ColMajor>(nullptr)); + CALL_SUBTEST_6(test_multithread_shuffle<RowMajor>(&test_allocator)); + CALL_SUBTEST_6(test_threadpool_allocate(&test_allocator)); } |