aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/lib/core/threadpool.cc
diff options
context:
space:
mode:
authorGravatar Dan Ringwalt <ringwalt@google.com>2016-12-06 12:17:17 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-06 13:01:13 -0800
commitd0f8bda2c7dad89474ee883c4c840e5ea8429f67 (patch)
tree1f2cbc2f2e16bb9840c9900bfa78d45d24747da5 /tensorflow/core/lib/core/threadpool.cc
parent34176742ab1ae12774f1bbe72f4fa2b6d078ce64 (diff)
Add a ParallelForWithWorkerId function.
This allows allocating a partial result for each thread. Writing to the partial result indexed by the worker id is thread safe. Change: 141209775
Diffstat (limited to 'tensorflow/core/lib/core/threadpool.cc')
-rw-r--r--tensorflow/core/lib/core/threadpool.cc9
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc
index a2245bb28e..59d77a3734 100644
--- a/tensorflow/core/lib/core/threadpool.cc
+++ b/tensorflow/core/lib/core/threadpool.cc
@@ -123,6 +123,15 @@ void ThreadPool::ParallelFor(int64 total, int64 cost_per_unit,
impl_->ParallelFor(total, cost_per_unit, std::move(fn));
}
+void ThreadPool::ParallelForWithWorkerId(
+ int64 total, int64 cost_per_unit,
+ const std::function<void(int64, int64, int)>& fn) {
+ impl_->ParallelFor(total, cost_per_unit,
+ [this, &fn](int64 start, int64 limit) {
+ fn(start, limit, CurrentThreadId());
+ });
+}
+
int ThreadPool::NumThreads() const { return impl_->NumThreads(); }
int ThreadPool::CurrentThreadId() const { return impl_->CurrentThreadId(); }