diff options
author | 2016-12-06 12:17:17 -0800 | |
---|---|---|
committer | 2016-12-06 13:01:13 -0800 | |
commit | d0f8bda2c7dad89474ee883c4c840e5ea8429f67 (patch) | |
tree | 1f2cbc2f2e16bb9840c9900bfa78d45d24747da5 /tensorflow/core/lib/core/threadpool.cc | |
parent | 34176742ab1ae12774f1bbe72f4fa2b6d078ce64 (diff) |
Add a ParallelForWithWorkerId function.
This allows allocating a partial result for each thread. Writing to the partial result indexed by the worker id is thread safe.
Change: 141209775
Diffstat (limited to 'tensorflow/core/lib/core/threadpool.cc')
-rw-r--r-- | tensorflow/core/lib/core/threadpool.cc | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc index a2245bb28e..59d77a3734 100644 --- a/tensorflow/core/lib/core/threadpool.cc +++ b/tensorflow/core/lib/core/threadpool.cc @@ -123,6 +123,15 @@ void ThreadPool::ParallelFor(int64 total, int64 cost_per_unit, impl_->ParallelFor(total, cost_per_unit, std::move(fn)); } +void ThreadPool::ParallelForWithWorkerId( + int64 total, int64 cost_per_unit, + const std::function<void(int64, int64, int)>& fn) { + impl_->ParallelFor(total, cost_per_unit, + [this, &fn](int64 start, int64 limit) { + fn(start, limit, CurrentThreadId()); + }); +} + int ThreadPool::NumThreads() const { return impl_->NumThreads(); } int ThreadPool::CurrentThreadId() const { return impl_->CurrentThreadId(); } |