aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-04-11 01:28:29 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-04-11 02:33:23 -0700
commit017498bc5d217b32d795545bca902ffd8246b110 (patch)
treec77b7b30d495131581aee8088964bdfa8a235816
parenteb161ecd0c756aa5c975cbc867e933f19d938b77 (diff)
tensorflow: support usage of eigen thread pool
Use eigen ThreadPool instead of tensorflow one if TENSORFLOW_USE_EIGEN_THREADPOOL is defined. This will allow to switch to the new non-blocking ThreadPool. Change: 119512280
-rw-r--r--tensorflow/contrib/linear_optimizer/kernels/resources_test.cc1
-rw-r--r--tensorflow/core/lib/core/threadpool.cc143
-rw-r--r--tensorflow/core/lib/core/threadpool.h26
-rw-r--r--tensorflow/core/lib/core/threadpool_test.cc1
4 files changed, 124 insertions, 47 deletions
diff --git a/tensorflow/contrib/linear_optimizer/kernels/resources_test.cc b/tensorflow/contrib/linear_optimizer/kernels/resources_test.cc
index 4b4c2f5fd7..060d29daab 100644
--- a/tensorflow/contrib/linear_optimizer/kernels/resources_test.cc
+++ b/tensorflow/contrib/linear_optimizer/kernels/resources_test.cc
@@ -164,7 +164,6 @@ TEST_F(DataByExampleTest, VisitUnavailable) {
signal(&updated_data);
});
wait(&completed_visit);
- EXPECT_FALSE(thread_pool.HasPendingClosures());
EXPECT_TRUE(errors::IsUnavailable(status));
}
diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc
index 07ace0560a..f4e952826a 100644
--- a/tensorflow/core/lib/core/threadpool.cc
+++ b/tensorflow/core/lib/core/threadpool.cc
@@ -15,6 +15,16 @@ limitations under the License.
#include "tensorflow/core/lib/core/threadpool.h"
+#ifdef TENSORFLOW_USE_EIGEN_THREADPOOL
+#define EIGEN_USE_THREADS
+#define EIGEN_USE_CUSTOM_THREAD_POOL
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#else
+#include <deque>
+#include <thread>
+#include <vector>
+#endif
+
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
@@ -24,26 +34,97 @@ limitations under the License.
namespace tensorflow {
namespace thread {
-struct ThreadPool::Waiter {
- condition_variable cv;
- bool ready;
+#ifdef TENSORFLOW_USE_EIGEN_THREADPOOL
+
+struct EigenEnvironment {
+ typedef Thread EnvThread;
+ struct Task {
+ std::function<void()> f;
+ uint64 trace_id;
+ };
+
+ Env* const env_;
+ const ThreadOptions thread_options_;
+ const string name_;
+
+ EigenEnvironment(Env* env, const ThreadOptions& thread_options,
+ const string& name)
+ : env_(env), thread_options_(thread_options), name_(name) {}
+
+ EnvThread* CreateThread(std::function<void()> f) {
+ return env_->StartThread(thread_options_, name_, [=]() {
+ // Set the processor flag to flush denormals to zero
+ port::ScopedFlushDenormal flush;
+ f();
+ });
+ }
+
+ Task CreateTask(std::function<void()> f) {
+ uint64 id = 0;
+ if (port::Tracing::IsActive()) {
+ id = port::Tracing::UniqueId();
+ port::Tracing::RecordEvent(port::Tracing::EventCategory::kScheduleClosure,
+ id);
+ }
+ return Task{std::move(f), id};
+ }
+
+ void ExecuteTask(const Task& t) {
+ if (t.trace_id != 0) {
+ port::Tracing::ScopedActivity region(
+ port::Tracing::EventCategory::kRunClosure, t.trace_id);
+ t.f();
+ } else {
+ t.f();
+ }
+ }
+};
+
+struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
+ Impl(Env* env, const ThreadOptions& thread_options, const string& name,
+ int num_threads)
+ : Eigen::ThreadPoolTempl<EigenEnvironment>(
+ num_threads, EigenEnvironment(env, thread_options, name)) {}
};
-ThreadPool::ThreadPool(Env* env, const string& name, int num_threads)
- : ThreadPool(env, ThreadOptions(), name, num_threads) {}
+#else
-ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options,
+struct ThreadPool::Impl {
+ Impl(Env* env, const ThreadOptions& thread_options, const string& name,
+ int num_threads);
+ ~Impl();
+ void Schedule(std::function<void()> fn);
+
+ private:
+ struct Waiter {
+ condition_variable cv;
+ bool ready;
+ };
+
+ struct Task {
+ std::function<void()> fn;
+ uint64 id;
+ };
+
+ void WorkerLoop();
+
+ const string name_;
+ mutex mu_;
+ std::vector<Thread*> threads_; // All threads
+ std::vector<Waiter*> waiters_; // Stack of waiting threads.
+ std::deque<Task> pending_; // Queue of pending work
+};
+
+ThreadPool::Impl::Impl(Env* env, const ThreadOptions& thread_options,
const string& name, int num_threads)
: name_(name) {
- CHECK_GE(num_threads, 1);
- string name_prefix = "tf_" + name_;
for (int i = 0; i < num_threads; i++) {
- threads_.push_back(env->StartThread(thread_options, name_prefix,
- [this]() { WorkerLoop(); }));
+ threads_.push_back(
+ env->StartThread(thread_options, name, [this]() { WorkerLoop(); }));
}
}
-ThreadPool::~ThreadPool() {
+ThreadPool::Impl::~Impl() {
{
// Wait for all work to get done.
mutex_lock l(mu_);
@@ -66,13 +147,7 @@ ThreadPool::~ThreadPool() {
}
}
-bool ThreadPool::HasPendingClosures() const {
- mutex_lock l(mu_);
- return pending_.size() != 0;
-}
-
-void ThreadPool::Schedule(std::function<void()> fn) {
- CHECK(fn != nullptr);
+void ThreadPool::Impl::Schedule(std::function<void()> fn) {
uint64 id = 0;
if (port::Tracing::IsActive()) {
id = port::Tracing::UniqueId();
@@ -90,7 +165,7 @@ void ThreadPool::Schedule(std::function<void()> fn) {
}
}
-void ThreadPool::WorkerLoop() {
+void ThreadPool::Impl::WorkerLoop() {
// Set the processor flag to flush denormals to zero
port::ScopedFlushDenormal flush;
@@ -107,22 +182,40 @@ void ThreadPool::WorkerLoop() {
}
}
// Pick up pending work
- Item item = pending_.front();
+ Task t = pending_.front();
pending_.pop_front();
- if (item.fn == nullptr) {
+ if (t.fn == nullptr) {
break;
}
mu_.unlock();
- if (item.id != 0) {
+ if (t.id != 0) {
port::Tracing::ScopedActivity region(
- port::Tracing::EventCategory::kRunClosure, item.id);
- item.fn();
+ port::Tracing::EventCategory::kRunClosure, t.id);
+ t.fn();
} else {
- item.fn();
+ t.fn();
}
mu_.lock();
}
}
+#endif
+
+ThreadPool::ThreadPool(Env* env, const string& name, int num_threads)
+ : ThreadPool(env, ThreadOptions(), name, num_threads) {}
+
+ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options,
+ const string& name, int num_threads) {
+ CHECK_GE(num_threads, 1);
+ impl_.reset(
+ new ThreadPool::Impl(env, thread_options, "tf_" + name, num_threads));
+}
+
+ThreadPool::~ThreadPool() {}
+
+void ThreadPool::Schedule(std::function<void()> fn) {
+ CHECK(fn != nullptr);
+ impl_->Schedule(std::move(fn));
+}
} // namespace thread
} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/threadpool.h b/tensorflow/core/lib/core/threadpool.h
index ef37dcf2d9..ae709e0824 100644
--- a/tensorflow/core/lib/core/threadpool.h
+++ b/tensorflow/core/lib/core/threadpool.h
@@ -16,13 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_LIB_CORE_THREADPOOL_H_
#define TENSORFLOW_LIB_CORE_THREADPOOL_H_
-#include <deque>
#include <functional>
-#include <thread>
-#include <vector>
+#include <memory>
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@@ -45,28 +42,15 @@ class ThreadPool {
// Wait until all scheduled work has finished and then destroy the
// set of threads.
- virtual ~ThreadPool();
+ ~ThreadPool();
// Schedule fn() for execution in the pool of threads.
- virtual void Schedule(std::function<void()> fn);
+ void Schedule(std::function<void()> fn);
- virtual bool HasPendingClosures() const;
+ struct Impl;
private:
- struct Waiter;
- struct Item {
- std::function<void()> fn;
- uint64 id;
- };
-
- void WorkerLoop();
-
- const string name_;
- mutable mutex mu_;
- std::vector<Thread*> threads_; // All threads
- std::vector<Waiter*> waiters_; // Stack of waiting threads.
- std::deque<Item> pending_; // Queue of pending work
-
+ std::unique_ptr<Impl> impl_;
TF_DISALLOW_COPY_AND_ASSIGN(ThreadPool);
};
diff --git a/tensorflow/core/lib/core/threadpool_test.cc b/tensorflow/core/lib/core/threadpool_test.cc
index 59ca99c299..f0edebdd62 100644
--- a/tensorflow/core/lib/core/threadpool_test.cc
+++ b/tensorflow/core/lib/core/threadpool_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <atomic>
#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"