From 017498bc5d217b32d795545bca902ffd8246b110 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 11 Apr 2016 01:28:29 -0800 Subject: tensorflow: support usage of eigen thread pool Use the Eigen ThreadPool instead of the TensorFlow one if TENSORFLOW_USE_EIGEN_THREADPOOL is defined. This will allow switching to the new non-blocking ThreadPool. Change: 119512280 --- .../linear_optimizer/kernels/resources_test.cc | 1 - tensorflow/core/lib/core/threadpool.cc | 143 +++++++++++++++++---- tensorflow/core/lib/core/threadpool.h | 26 +--- tensorflow/core/lib/core/threadpool_test.cc | 1 + 4 files changed, 124 insertions(+), 47 deletions(-) diff --git a/tensorflow/contrib/linear_optimizer/kernels/resources_test.cc b/tensorflow/contrib/linear_optimizer/kernels/resources_test.cc index 4b4c2f5fd7..060d29daab 100644 --- a/tensorflow/contrib/linear_optimizer/kernels/resources_test.cc +++ b/tensorflow/contrib/linear_optimizer/kernels/resources_test.cc @@ -164,7 +164,6 @@ TEST_F(DataByExampleTest, VisitUnavailable) { signal(&updated_data); }); wait(&completed_visit); - EXPECT_FALSE(thread_pool.HasPendingClosures()); EXPECT_TRUE(errors::IsUnavailable(status)); } diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc index 07ace0560a..f4e952826a 100644 --- a/tensorflow/core/lib/core/threadpool.cc +++ b/tensorflow/core/lib/core/threadpool.cc @@ -15,6 +15,16 @@ limitations under the License. #include "tensorflow/core/lib/core/threadpool.h" +#ifdef TENSORFLOW_USE_EIGEN_THREADPOOL +#define EIGEN_USE_THREADS +#define EIGEN_USE_CUSTOM_THREAD_POOL +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#else +#include +#include +#include +#endif + #include "tensorflow/core/platform/denormal.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" @@ -24,26 +34,97 @@ limitations under the License. 
namespace tensorflow { namespace thread { -struct ThreadPool::Waiter { - condition_variable cv; - bool ready; +#ifdef TENSORFLOW_USE_EIGEN_THREADPOOL + +struct EigenEnvironment { + typedef Thread EnvThread; + struct Task { + std::function f; + uint64 trace_id; + }; + + Env* const env_; + const ThreadOptions thread_options_; + const string name_; + + EigenEnvironment(Env* env, const ThreadOptions& thread_options, + const string& name) + : env_(env), thread_options_(thread_options), name_(name) {} + + EnvThread* CreateThread(std::function f) { + return env_->StartThread(thread_options_, name_, [=]() { + // Set the processor flag to flush denormals to zero + port::ScopedFlushDenormal flush; + f(); + }); + } + + Task CreateTask(std::function f) { + uint64 id = 0; + if (port::Tracing::IsActive()) { + id = port::Tracing::UniqueId(); + port::Tracing::RecordEvent(port::Tracing::EventCategory::kScheduleClosure, + id); + } + return Task{std::move(f), id}; + } + + void ExecuteTask(const Task& t) { + if (t.trace_id != 0) { + port::Tracing::ScopedActivity region( + port::Tracing::EventCategory::kRunClosure, t.trace_id); + t.f(); + } else { + t.f(); + } + } +}; + +struct ThreadPool::Impl : Eigen::ThreadPoolTempl { + Impl(Env* env, const ThreadOptions& thread_options, const string& name, + int num_threads) + : Eigen::ThreadPoolTempl( + num_threads, EigenEnvironment(env, thread_options, name)) {} }; -ThreadPool::ThreadPool(Env* env, const string& name, int num_threads) - : ThreadPool(env, ThreadOptions(), name, num_threads) {} +#else -ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options, +struct ThreadPool::Impl { + Impl(Env* env, const ThreadOptions& thread_options, const string& name, + int num_threads); + ~Impl(); + void Schedule(std::function fn); + + private: + struct Waiter { + condition_variable cv; + bool ready; + }; + + struct Task { + std::function fn; + uint64 id; + }; + + void WorkerLoop(); + + const string name_; + mutex mu_; + std::vector threads_; // 
All threads + std::vector waiters_; // Stack of waiting threads. + std::deque pending_; // Queue of pending work +}; + +ThreadPool::Impl::Impl(Env* env, const ThreadOptions& thread_options, const string& name, int num_threads) : name_(name) { - CHECK_GE(num_threads, 1); - string name_prefix = "tf_" + name_; for (int i = 0; i < num_threads; i++) { - threads_.push_back(env->StartThread(thread_options, name_prefix, - [this]() { WorkerLoop(); })); + threads_.push_back( + env->StartThread(thread_options, name, [this]() { WorkerLoop(); })); } } -ThreadPool::~ThreadPool() { +ThreadPool::Impl::~Impl() { { // Wait for all work to get done. mutex_lock l(mu_); @@ -66,13 +147,7 @@ ThreadPool::~ThreadPool() { } } -bool ThreadPool::HasPendingClosures() const { - mutex_lock l(mu_); - return pending_.size() != 0; -} - -void ThreadPool::Schedule(std::function fn) { - CHECK(fn != nullptr); +void ThreadPool::Impl::Schedule(std::function fn) { uint64 id = 0; if (port::Tracing::IsActive()) { id = port::Tracing::UniqueId(); @@ -90,7 +165,7 @@ void ThreadPool::Schedule(std::function fn) { } } -void ThreadPool::WorkerLoop() { +void ThreadPool::Impl::WorkerLoop() { // Set the processor flag to flush denormals to zero port::ScopedFlushDenormal flush; @@ -107,22 +182,40 @@ void ThreadPool::WorkerLoop() { } } // Pick up pending work - Item item = pending_.front(); + Task t = pending_.front(); pending_.pop_front(); - if (item.fn == nullptr) { + if (t.fn == nullptr) { break; } mu_.unlock(); - if (item.id != 0) { + if (t.id != 0) { port::Tracing::ScopedActivity region( - port::Tracing::EventCategory::kRunClosure, item.id); - item.fn(); + port::Tracing::EventCategory::kRunClosure, t.id); + t.fn(); } else { - item.fn(); + t.fn(); } mu_.lock(); } } +#endif + +ThreadPool::ThreadPool(Env* env, const string& name, int num_threads) + : ThreadPool(env, ThreadOptions(), name, num_threads) {} + +ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options, + const string& name, int num_threads) { 
+ CHECK_GE(num_threads, 1); + impl_.reset( + new ThreadPool::Impl(env, thread_options, "tf_" + name, num_threads)); +} + +ThreadPool::~ThreadPool() {} + +void ThreadPool::Schedule(std::function fn) { + CHECK(fn != nullptr); + impl_->Schedule(std::move(fn)); +} } // namespace thread } // namespace tensorflow diff --git a/tensorflow/core/lib/core/threadpool.h b/tensorflow/core/lib/core/threadpool.h index ef37dcf2d9..ae709e0824 100644 --- a/tensorflow/core/lib/core/threadpool.h +++ b/tensorflow/core/lib/core/threadpool.h @@ -16,13 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_LIB_CORE_THREADPOOL_H_ #define TENSORFLOW_LIB_CORE_THREADPOOL_H_ -#include #include -#include -#include +#include #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -45,28 +42,15 @@ class ThreadPool { // Wait until all scheduled work has finished and then destroy the // set of threads. - virtual ~ThreadPool(); + ~ThreadPool(); // Schedule fn() for execution in the pool of threads. - virtual void Schedule(std::function fn); + void Schedule(std::function fn); - virtual bool HasPendingClosures() const; + struct Impl; private: - struct Waiter; - struct Item { - std::function fn; - uint64 id; - }; - - void WorkerLoop(); - - const string name_; - mutable mutex mu_; - std::vector threads_; // All threads - std::vector waiters_; // Stack of waiting threads. - std::deque pending_; // Queue of pending work - + std::unique_ptr impl_; TF_DISALLOW_COPY_AND_ASSIGN(ThreadPool); }; diff --git a/tensorflow/core/lib/core/threadpool_test.cc b/tensorflow/core/lib/core/threadpool_test.cc index 59ca99c299..f0edebdd62 100644 --- a/tensorflow/core/lib/core/threadpool_test.cc +++ b/tensorflow/core/lib/core/threadpool_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" -- cgit v1.2.3