diff options
author | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
---|---|---|
committer | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
commit | f41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch) | |
tree | ef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/core/lib/core/threadpool.cc |
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation
using data flow graphs.
Base CL: 107276108
Diffstat (limited to 'tensorflow/core/lib/core/threadpool.cc')
-rw-r--r-- | tensorflow/core/lib/core/threadpool.cc | 108 |
1 files changed, 108 insertions, 0 deletions
diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc new file mode 100644 index 0000000000..e9b84d3102 --- /dev/null +++ b/tensorflow/core/lib/core/threadpool.cc @@ -0,0 +1,108 @@ +#include "tensorflow/core/lib/core/threadpool.h" + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/tracing.h" + +namespace tensorflow { +namespace thread { + +struct ThreadPool::Waiter { + condition_variable cv; + bool ready; +}; + +ThreadPool::ThreadPool(Env* env, const string& name, int num_threads) + : ThreadPool(env, ThreadOptions(), name, num_threads) {} + +ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options, + const string& name, int num_threads) + : name_(name) { + CHECK_GE(num_threads, 1); + string name_prefix = "tf_" + name_; + for (int i = 0; i < num_threads; i++) { + threads_.push_back(env->StartThread(thread_options, name_prefix, + [this]() { WorkerLoop(); })); + } +} + +ThreadPool::~ThreadPool() { + { + // Wait for all work to get done. + mutex_lock l(mu_); + + // Inform every thread to exit. + for (size_t i = 0; i < threads_.size(); ++i) { + pending_.push_back({nullptr, 0}); + } + + // Wakeup all waiters. + for (auto w : waiters_) { + w->ready = true; + w->cv.notify_one(); + } + } + + // Wait for threads to finish. + for (auto t : threads_) { + delete t; + } +} + +bool ThreadPool::HasPendingClosures() const { + mutex_lock l(mu_); + return pending_.size() != 0; +} + +void ThreadPool::Schedule(std::function<void()> fn) { + CHECK(fn != nullptr); + uint64 id = 0; + if (port::Tracing::IsActive()) { + id = port::Tracing::UniqueId(); + port::Tracing::RecordEvent(port::Tracing::EventCategory::kScheduleClosure, + id); + } + + mutex_lock l(mu_); + pending_.push_back({fn, id}); + if (!waiters_.empty()) { + Waiter* w = waiters_.back(); + waiters_.pop_back(); + w->ready = true; + w->cv.notify_one(); + } +} + +void ThreadPool::WorkerLoop() { + port::Tracing::RegisterCurrentThread(name_.c_str()); + mutex_lock l(mu_); + Waiter w; + while (true) { + while (pending_.empty()) { + // Wait for work to be assigned to me + w.ready = false; + waiters_.push_back(&w); + while (!w.ready) { + w.cv.wait(l); + } + } + // Pick up pending work + Item item = pending_.front(); + pending_.pop_front(); + if (item.fn == nullptr) { + break; + } + mu_.unlock(); + if (item.id != 0) { + port::Tracing::ScopedActivity region( + port::Tracing::EventCategory::kRunClosure, item.id); + item.fn(); + } else { + item.fn(); + } + mu_.lock(); + } +} + +} // namespace thread +} // namespace tensorflow |