diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-20 18:35:09 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-20 18:39:18 -0700 |
commit | f283f3ac5d7b6de8cadc9c1cee6886b187319afd (patch) | |
tree | ccb4315c049ef96fb4aaf11bb1bc743e97386ead /tensorflow/core/lib | |
parent | 684b3e02e098cb6fda5569fb7f7990ff57248e5a (diff) |
Add an API which gives explicit control over shard sizes and introspection into the number of shards used. This is a variant of ThreadPool::ParallelFor.
PiperOrigin-RevId: 213920649
Diffstat (limited to 'tensorflow/core/lib')
-rw-r--r-- | tensorflow/core/lib/core/threadpool.cc | 49 | ||||
-rw-r--r-- | tensorflow/core/lib/core/threadpool.h | 14 | ||||
-rw-r--r-- | tensorflow/core/lib/core/threadpool_test.cc | 61 |
3 files changed, 124 insertions, 0 deletions
diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc index 99684ae47b..9ccd911b0e 100644 --- a/tensorflow/core/lib/core/threadpool.cc +++ b/tensorflow/core/lib/core/threadpool.cc @@ -17,6 +17,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/platform/context.h" #include "tensorflow/core/platform/denormal.h" #include "tensorflow/core/platform/logging.h" @@ -120,6 +121,54 @@ void ThreadPool::Schedule(std::function<void()> fn) { impl_->Schedule(std::move(fn)); } +int ThreadPool::NumShardsUsedByTransformRangeConcurrently( + const int64 block_size, const int64 total) { + if (block_size <= 0 || total <= 1 || total <= block_size || + NumThreads() == 1) { + return 1; + } + return (total + block_size - 1) / block_size; +} + +// This functionality is similar to parallelFor, except that reasoning about +// the number of shards used is significantly easier. +void ThreadPool::TransformRangeConcurrently( + const int64 block_size, const int64 total, + const std::function<void(int64, int64)>& fn) { + const int num_shards_used = + NumShardsUsedByTransformRangeConcurrently(block_size, total); + if (num_shards_used == 1) { + fn(0, total); + return; + } + + // Adapted from Eigen's parallelFor implementation. + BlockingCounter counter(num_shards_used); + std::function<void(int64, int64)> handle_range = + [=, &handle_range, &counter, &fn](int64 first, int64 last) { + while (last - first > block_size) { + // Find something near the midpoint which is a multiple of block size. + const int64 mid = first + ((last - first) / 2 + block_size - 1) / + block_size * block_size; + Schedule([=, &handle_range]() { handle_range(mid, last); }); + last = mid; + } + // Single block or less, execute directly. + fn(first, last); + counter.DecrementCount(); // The shard is done. 
+ }; + if (num_shards_used <= NumThreads()) { + // Avoid a thread hop by running the root of the tree and one block on the + // main thread. + handle_range(0, total); + } else { + // Execute the root in the thread pool to avoid running work on more than + // numThreads() threads. + Schedule([=, &handle_range]() { handle_range(0, total); }); + } + counter.Wait(); +} + void ThreadPool::ParallelFor(int64 total, int64 cost_per_unit, std::function<void(int64, int64)> fn) { impl_->ParallelFor(total, cost_per_unit, std::move(fn)); diff --git a/tensorflow/core/lib/core/threadpool.h b/tensorflow/core/lib/core/threadpool.h index 74df7c84a4..e14ad7ac64 100644 --- a/tensorflow/core/lib/core/threadpool.h +++ b/tensorflow/core/lib/core/threadpool.h @@ -59,6 +59,20 @@ class ThreadPool { // Schedules fn() for execution in the pool of threads. void Schedule(std::function<void()> fn); + // Requires 0 < block_size <= total. + // Spawns k threads and calls fn(i*block_size, (i+1)*block_size) from the + // ith thread (i>=0). When (i+1)*block_size > total, fn(i*block_size, total) + // is called instead. k = NumShardsUsedByTransformRangeConcurrently(...). + // Note that when there aren't enough threads in the pool to achieve full + // parallelism, function calls will be automatically queued. + void TransformRangeConcurrently(const int64 block_size, const int64 total, + const std::function<void(int64, int64)>& fn); + + // Returns the number of threads spawned by calling TransformRangeConcurrently + // with these parameters. + int NumShardsUsedByTransformRangeConcurrently(const int64 block_size, + const int64 total); + // ParallelFor shards the "total" units of work assuming each unit of work // having roughly "cost_per_unit" cost, in cycles. Each unit of work is // indexed 0, 1, ..., total - 1. 
Each shard contains 1 or more units of work diff --git a/tensorflow/core/lib/core/threadpool_test.cc b/tensorflow/core/lib/core/threadpool_test.cc index 320f3ebb83..db996b783f 100644 --- a/tensorflow/core/lib/core/threadpool_test.cc +++ b/tensorflow/core/lib/core/threadpool_test.cc @@ -61,6 +61,67 @@ TEST(ThreadPool, DoWork) { } } +void RunSharding(int64 block_size, int64 total, ThreadPool* threads) { + mutex mu; + int64 num_shards = 0; + int64 num_done_work = 0; + std::vector<bool> work(total, false); + threads->TransformRangeConcurrently( + block_size, total, + [=, &mu, &num_shards, &num_done_work, &work](int64 start, int64 end) { + VLOG(1) << "Shard [" << start << "," << end << ")"; + EXPECT_GE(start, 0); + EXPECT_LE(end, total); + mutex_lock l(mu); + ++num_shards; + for (; start < end; ++start) { + EXPECT_FALSE(work[start]); // No duplicate + ++num_done_work; + work[start] = true; + } + }); + LOG(INFO) << block_size << " " << total; + const int64 num_workers = (total + block_size - 1) / block_size; + EXPECT_EQ(num_done_work, total); + if (num_workers < threads->NumThreads()) { + // If the intention is to limit the parallelism explicitly, we'd + // better honor it. Ideally, even if per_thread_max_parallelism > + // num_workers, we should expect that Shard() implementation do + // not over-shard. Unfortunately, ThreadPoolDevice::parallelFor + // tends to over-shard. 
+ EXPECT_LE(num_shards, 1 + num_workers); + } +} + +// Adapted from work_sharder_test.cc +TEST(SparseUtilsTest, TransformRangeConcurrently) { + ThreadPool threads(Env::Default(), "test", 16); + for (auto block_size : {1, 7, 10, 64, 100, 256, 1000, 9999}) { + for (auto diff : {0, 1, 11, 102, 1003, 10005, 1000007}) { + const int64 total = block_size + diff; + RunSharding(block_size, total, &threads); + } + } +} + +TEST(SparseUtilsTest, NumShardsUsedByTransformRangeConcurrently) { + ThreadPool threads(Env::Default(), "test", 16); + EXPECT_EQ(1, threads.NumShardsUsedByTransformRangeConcurrently( + 3 /* block_size */, 3 /* total */)); + EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently( + 3 /* block_size */, 4 /* total */)); + EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently( + 3 /* block_size */, 5 /* total */)); + EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently( + 3 /* block_size */, 6 /* total */)); + EXPECT_EQ(3, threads.NumShardsUsedByTransformRangeConcurrently( + 3 /* block_size */, 7 /* total */)); + EXPECT_EQ(7, threads.NumShardsUsedByTransformRangeConcurrently( + 1 /* block_size */, 7 /* total */)); + EXPECT_EQ(1, threads.NumShardsUsedByTransformRangeConcurrently( + 0 /* block_size */, 7 /* total */)); +} + TEST(ThreadPool, ParallelFor) { Context outer_context(ContextKind::kThread); // Make ParallelFor use as many threads as possible. |