diff options
Diffstat (limited to 'unsupported')
-rw-r--r-- | unsupported/Eigen/CXX11/ThreadPool | 10 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/ThreadPool/ThreadLocal.h | 222 | ||||
-rw-r--r-- | unsupported/test/CMakeLists.txt | 1 | ||||
-rw-r--r-- | unsupported/test/cxx11_tensor_thread_local.cpp | 158 |
4 files changed, 385 insertions, 6 deletions
diff --git a/unsupported/Eigen/CXX11/ThreadPool b/unsupported/Eigen/CXX11/ThreadPool index 7a795da3d..71a6afe39 100644 --- a/unsupported/Eigen/CXX11/ThreadPool +++ b/unsupported/Eigen/CXX11/ThreadPool @@ -45,11 +45,7 @@ #include <functional> #include <memory> #include <utility> -#include "src/util/CXX11Meta.h" -#include "src/util/MaxSizeVector.h" -#include "src/ThreadPool/ThreadLocal.h" -#ifndef EIGEN_THREAD_LOCAL // There are non-parenthesized calls to "max" in the <unordered_map> header, // which trigger a check in test/main.h causing compilation to fail. // We work around the check here by removing the check for max in @@ -58,7 +54,11 @@ #undef max #endif #include <unordered_map> -#endif + +#include "src/util/CXX11Meta.h" +#include "src/util/MaxSizeVector.h" + +#include "src/ThreadPool/ThreadLocal.h" #include "src/ThreadPool/ThreadYield.h" #include "src/ThreadPool/ThreadCancel.h" #include "src/ThreadPool/EventCount.h" diff --git a/unsupported/Eigen/CXX11/src/ThreadPool/ThreadLocal.h b/unsupported/Eigen/CXX11/src/ThreadPool/ThreadLocal.h index 696c2d03b..63a168372 100644 --- a/unsupported/Eigen/CXX11/src/ThreadPool/ThreadLocal.h +++ b/unsupported/Eigen/CXX11/src/ThreadPool/ThreadLocal.h @@ -60,6 +60,226 @@ #endif #endif // defined(__ANDROID__) && defined(__clang__) -#endif // EIGEN_AVOID_THREAD_LOCAL +#endif // EIGEN_AVOID_THREAD_LOCAL + +namespace Eigen { + +// Thread local container for elements of type Factory::T, that does not use +// thread local storage. It will lazily initialize elements for each thread that +// accesses this object. As long as the number of unique threads accessing this +// storage is smaller than `kAllocationMultiplier * num_threads`, it is +// lock-free and wait-free. Otherwise it will use a mutex for synchronization. +// +// Example: +// +// struct Counter { +// int value; +// } +// +// struct CounterFactory { +// using T = Counter; +// +// Counter Allocate() { return {0}; } +// void Release(Counter&) {} +// }; +// +// CounterFactory factory; +// Eigen::ThreadLocal<CounterFactory> counter(factory, 10); +// +// // Each thread will have access to it's own counter object. +// Counter& cnt = counter.local(); +// cnt++; +// +// WARNING: Eigen::ThreadLocal uses the OS-specific value returned by +// std::this_thread::get_id() to identify threads. This value is not guaranteed +// to be unique except for the life of the thread. A newly created thread may +// get an OS-specific ID equal to that of an already destroyed thread. +// +// Somewhat similar to TBB thread local storage, with similar restrictions: +// https://www.threadingbuildingblocks.org/docs/help/reference/thread_local_storage/enumerable_thread_specific_cls.html +// +template<typename Factory> +class ThreadLocal { + // We allocate larger storage for thread local data, than the number of + // threads, because thread pool size might grow, or threads outside of a + // thread pool might steal the work. We still expect this number to be of the + // same order of magnitude as the original `num_threads`. + static constexpr int kAllocationMultiplier = 4; + + using T = typename Factory::T; + + // We preallocate default constructed elements in MaxSizedVector. + static_assert(std::is_default_constructible<T>::value, + "ThreadLocal data type must be default constructible"); + + public: + explicit ThreadLocal(Factory& factory, int num_threads) + : factory_(factory), + num_records_(kAllocationMultiplier * num_threads), + data_(num_records_), + ptr_(num_records_), + filled_records_(0) { + eigen_assert(num_threads >= 0); + data_.resize(num_records_); + for (int i = 0; i < num_records_; ++i) { + ptr_.emplace_back(nullptr); + } + } + + T& local() { + std::thread::id this_thread = std::this_thread::get_id(); + if (num_records_ == 0) return SpilledLocal(this_thread); + + std::size_t h = std::hash<std::thread::id>()(this_thread); + const int start_idx = h % num_records_; + + // NOTE: From the definition of `std::this_thread::get_id()` it is + // guaranteed that we never can have concurrent insertions with the same key + // to our hash-map like data structure. If we didn't find an element during + // the initial traversal, it's guaranteed that no one else could have + // inserted it while we are in this function. This allows to massively + // simplify out lock-free insert-only hash map. + + // Check if we already have an element for `this_thread`. + int idx = start_idx; + while (ptr_[idx].load() != nullptr) { + ThreadIdAndValue& record = *(ptr_[idx].load()); + if (record.thread_id == this_thread) return record.value; + + idx += 1; + if (idx >= num_records_) idx -= num_records_; + if (idx == start_idx) break; + } + + // If we are here, it means that we found an insertion point in lookup + // table at `idx`, or we did a full traversal and table is full. + + // If lock-free storage is full, fallback on mutex. + if (filled_records_.load() >= num_records_) + return SpilledLocal(this_thread); + + // We double check that we still have space to insert an element into a lock + // free storage. If old value in `filled_records_` is larger than the + // records capacity, it means that some other thread added an element while + // we were traversing lookup table. + int insertion_index = + filled_records_.fetch_add(1, std::memory_order_relaxed); + if (insertion_index >= num_records_) return SpilledLocal(this_thread); + + // At this point it's guaranteed that we can access to + // data_[insertion_index_] without a data race. + data_[insertion_index] = {this_thread, factory_.Allocate()}; + + // That's the pointer we'll put into the lookup table. + ThreadIdAndValue* inserted = &data_[insertion_index]; + + // We'll use nullptr pointer to ThreadIdAndValue in a compare-and-swap loop. + ThreadIdAndValue* empty = nullptr; + + // Now we have to find an insertion point into the lookup table. We start + // from the `idx` that was identified as an insertion point above, it's + // guaranteed that we will have an empty record somewhere in a lookup table + // (because we created a record in the `data_`). + const int insertion_idx = idx; + + do { + // Always start search from the original insertion candidate. + idx = insertion_idx; + while (ptr_[idx].load() != nullptr) { + idx += 1; + if (idx >= num_records_) idx -= num_records_; + // If we did a full loop, it means that we don't have any free entries + // in the lookup table, and this means that something is terribly wrong. + eigen_assert(idx != insertion_idx); + } + // Atomic CAS of the pointer guarantees that any other thread, that will + // follow this pointer will see all the mutations in the `data_`. + } while (!ptr_[idx].compare_exchange_weak(empty, inserted)); + + return inserted->value; + } + + // WARN: It's not thread safe to call it concurrently with `local()`. + void ForEach(std::function<void(std::thread::id, T & )> f) { + // Reading directly from `data_` is unsafe, because only CAS to the + // record in `ptr_` makes all changes visible to other threads. + for (auto& ptr : ptr_) { + ThreadIdAndValue* record = ptr.load(); + if (record == nullptr) continue; + f(record->thread_id, record->value); + } + + // We did not spill into the map based storage. + if (filled_records_.load(std::memory_order_relaxed) < num_records_) return; + + // Adds a happens before edge from the last call to SpilledLocal(). + std::unique_lock<std::mutex> lock(mu_); + for (auto& kv : per_thread_map_) { + f(kv.first, kv.second); + } + } + + // WARN: It's not thread safe to call it concurrently with `local()`. + ~ThreadLocal() { + // Reading directly from `data_` is unsafe, because only CAS to the record + // in `ptr_` makes all changes visible to other threads. + for (auto& ptr : ptr_) { + ThreadIdAndValue* record = ptr.load(); + if (record == nullptr) continue; + factory_.Release(record->value); + } + + // We did not spill into the map based storage. + if (filled_records_.load(std::memory_order_relaxed) < num_records_) return; + + // Adds a happens before edge from the last call to SpilledLocal(). + std::unique_lock<std::mutex> lock(mu_); + for (auto& kv : per_thread_map_) { + factory_.Release(kv.second); + } + } + + private: + struct ThreadIdAndValue { + std::thread::id thread_id; + T value; + }; + + // Use unordered map guarded by a mutex when lock free storage is full. + T& SpilledLocal(std::thread::id this_thread) { + std::unique_lock<std::mutex> lock(mu_); + + auto it = per_thread_map_.find(this_thread); + if (it == per_thread_map_.end()) { + auto result = per_thread_map_.emplace(this_thread, factory_.Allocate()); + eigen_assert(result.second); + return (*result.first).second; + } else { + return it->second; + } + } + + Factory& factory_; + const int num_records_; + + // Storage that backs lock-free lookup table `ptr_`. Records stored in this + // storage contiguously starting from index 0. + MaxSizeVector<ThreadIdAndValue> data_; + + // Atomic pointers to the data stored in `data_`. Used as a lookup table for + // linear probing hash map (https://en.wikipedia.org/wiki/Linear_probing). + MaxSizeVector<std::atomic<ThreadIdAndValue*>> ptr_; + + // Number of records stored in the `data_`. + std::atomic<int> filled_records_; + + // We fallback on per thread map if lock-free storage is full. In practice + // this should never happen, if `num_threads` is a reasonable estimate of the + // number of threads running in a system. + std::mutex mu_; // Protects per_thread_map_. + std::unordered_map<std::thread::id, T> per_thread_map_; +}; + +} // namespace Eigen #endif // EIGEN_CXX11_THREADPOOL_THREAD_LOCAL_H diff --git a/unsupported/test/CMakeLists.txt b/unsupported/test/CMakeLists.txt index 42a450a85..f1f109ecb 100644 --- a/unsupported/test/CMakeLists.txt +++ b/unsupported/test/CMakeLists.txt @@ -201,6 +201,7 @@ if(EIGEN_TEST_CXX11) ei_add_test(cxx11_tensor_shuffling) ei_add_test(cxx11_tensor_striding) ei_add_test(cxx11_tensor_notification "-pthread" "${CMAKE_THREAD_LIBS_INIT}") + ei_add_test(cxx11_tensor_thread_local "-pthread" "${CMAKE_THREAD_LIBS_INIT}") ei_add_test(cxx11_tensor_thread_pool "-pthread" "${CMAKE_THREAD_LIBS_INIT}") ei_add_test(cxx11_tensor_executor "-pthread" "${CMAKE_THREAD_LIBS_INIT}") ei_add_test(cxx11_tensor_ref) diff --git a/unsupported/test/cxx11_tensor_thread_local.cpp b/unsupported/test/cxx11_tensor_thread_local.cpp new file mode 100644 index 000000000..dd43ab9d1 --- /dev/null +++ b/unsupported/test/cxx11_tensor_thread_local.cpp @@ -0,0 +1,158 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#define EIGEN_USE_THREADS + +#include <iostream> +#include <unordered_set> + +#include "main.h" +#include <Eigen/CXX11/ThreadPool> + +class Counter { + public: + Counter() : Counter(0) {} + explicit Counter(int value) + : created_by_(std::this_thread::get_id()), value_(value) {} + + void inc() { + // Check that mutation happens only in a thread that created this counter. + VERIFY_IS_EQUAL(std::this_thread::get_id(), created_by_); + value_++; + } + int value() { return value_; } + + private: + std::thread::id created_by_; + int value_; +}; + +struct CounterFactory { + using T = Counter; + + T Allocate() { return Counter(0); } + void Release(T&) {} +}; + +void test_simple_thread_local() { + CounterFactory factory; + int num_threads = internal::random<int>(4, 32); + Eigen::ThreadPool thread_pool(num_threads); + Eigen::ThreadLocal<CounterFactory> counter(factory, num_threads); + + int num_tasks = 3 * num_threads; + Eigen::Barrier barrier(num_tasks); + + for (int i = 0; i < num_tasks; ++i) { + thread_pool.Schedule([&counter, &barrier]() { + Counter& local = counter.local(); + local.inc(); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + barrier.Notify(); + }); + } + + barrier.Wait(); + + counter.ForEach( + [](std::thread::id, Counter& cnt) { VERIFY_IS_EQUAL(cnt.value(), 3); }); +} + +void test_zero_sized_thread_local() { + CounterFactory factory; + Eigen::ThreadLocal<CounterFactory> counter(factory, 0); + + Counter& local = counter.local(); + local.inc(); + + int total = 0; + counter.ForEach([&total](std::thread::id, Counter& cnt) { + total += cnt.value(); + VERIFY_IS_EQUAL(cnt.value(), 1); + }); + + VERIFY_IS_EQUAL(total, 1); +} + +// All thread local values fits into the lock-free storage. +void test_large_number_of_tasks_no_spill() { + CounterFactory factory; + int num_threads = internal::random<int>(4, 32); + Eigen::ThreadPool thread_pool(num_threads); + Eigen::ThreadLocal<CounterFactory> counter(factory, num_threads); + + int num_tasks = 10000; + Eigen::Barrier barrier(num_tasks); + + for (int i = 0; i < num_tasks; ++i) { + thread_pool.Schedule([&counter, &barrier]() { + Counter& local = counter.local(); + local.inc(); + barrier.Notify(); + }); + } + + barrier.Wait(); + + int total = 0; + std::unordered_set<std::thread::id> unique_threads; + + counter.ForEach([&](std::thread::id id, Counter& cnt) { + total += cnt.value(); + unique_threads.insert(id); + }); + + VERIFY_IS_EQUAL(total, num_tasks); + // Not all threads in a pool might be woken up to execute submitted tasks. + // Also thread_pool.Schedule() might use current thread if queue is full. + VERIFY_IS_EQUAL( + unique_threads.size() <= (static_cast<size_t>(num_threads + 1)), true); +} + +// Lock free thread local storage is too small to fit all the unique threads, +// and it spills to a map guarded by a mutex. +void test_large_number_of_tasks_with_spill() { + CounterFactory factory; + int num_threads = internal::random<int>(4, 32); + Eigen::ThreadPool thread_pool(num_threads); + Eigen::ThreadLocal<CounterFactory> counter(factory, 1); // This is too small + + int num_tasks = 10000; + Eigen::Barrier barrier(num_tasks); + + for (int i = 0; i < num_tasks; ++i) { + thread_pool.Schedule([&counter, &barrier]() { + Counter& local = counter.local(); + local.inc(); + barrier.Notify(); + }); + } + + barrier.Wait(); + + int total = 0; + std::unordered_set<std::thread::id> unique_threads; + + counter.ForEach([&](std::thread::id id, Counter& cnt) { + total += cnt.value(); + unique_threads.insert(id); + }); + + VERIFY_IS_EQUAL(total, num_tasks); + // Not all threads in a pool might be woken up to execute submitted tasks. + // Also thread_pool.Schedule() might use current thread if queue is full. + VERIFY_IS_EQUAL( + unique_threads.size() <= (static_cast<size_t>(num_threads + 1)), true); +} + +EIGEN_DECLARE_TEST(cxx11_tensor_thread_local) { + CALL_SUBTEST(test_simple_thread_local()); + CALL_SUBTEST(test_zero_sized_thread_local()); + CALL_SUBTEST(test_large_number_of_tasks_no_spill()); + CALL_SUBTEST(test_large_number_of_tasks_with_spill()); +} |