author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-09-20 18:35:09 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-09-20 18:39:18 -0700
commit    f283f3ac5d7b6de8cadc9c1cee6886b187319afd (patch)
tree      ccb4315c049ef96fb4aaf11bb1bc743e97386ead /tensorflow/core/lib
parent    684b3e02e098cb6fda5569fb7f7990ff57248e5a (diff)
Add an API that gives explicit control over shard sizes and introspection into the number of shards used. This is a variant of threadpool::parallelFor.
PiperOrigin-RevId: 213920649
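
For orientation, a minimal caller-side sketch of the new API (hypothetical example, not part of this change; it assumes a constructed tensorflow::thread::ThreadPool named pool and relies on the implementation detail, visible in the diff below, that each shard's `first` is a multiple of block_size):

  const int64 block_size = 256;
  const int64 total = 10000;
  // ceil(10000 / 256) = 40 shards, assuming the pool has more than one thread.
  const int num_shards =
      pool.NumShardsUsedByTransformRangeConcurrently(block_size, total);
  std::vector<double> partial_sums(num_shards, 0.0);  // one slot per shard
  pool.TransformRangeConcurrently(
      block_size, total, [&](int64 first, int64 last) {
        // Receives a half-open range [first, last); first / block_size is a
        // stable, race-free index into the per-shard storage.
        double sum = 0.0;
        for (int64 i = first; i < last; ++i) sum += static_cast<double>(i);
        partial_sums[first / block_size] = sum;
      });

Because the shard count is known before any work runs, per-shard scratch space can be sized exactly and written without locks.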
Diffstat (limited to 'tensorflow/core/lib')
-rw-r--r--  tensorflow/core/lib/core/threadpool.cc       49
-rw-r--r--  tensorflow/core/lib/core/threadpool.h        14
-rw-r--r--  tensorflow/core/lib/core/threadpool_test.cc  61
3 files changed, 124 insertions, 0 deletions
diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc
index 99684ae47b..9ccd911b0e 100644
--- a/tensorflow/core/lib/core/threadpool.cc
+++ b/tensorflow/core/lib/core/threadpool.cc
@@ -17,6 +17,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/logging.h"
@@ -120,6 +121,54 @@ void ThreadPool::Schedule(std::function<void()> fn) {
impl_->Schedule(std::move(fn));
}
+int ThreadPool::NumShardsUsedByTransformRangeConcurrently(
+ const int64 block_size, const int64 total) {
+ if (block_size <= 0 || total <= 1 || total <= block_size ||
+ NumThreads() == 1) {
+ return 1;
+ }
+ return (total + block_size - 1) / block_size;
+}
+
+// This functionality is similar to parallelFor, except that reasoning about
+// the number of shards used is significantly easier.
+void ThreadPool::TransformRangeConcurrently(
+ const int64 block_size, const int64 total,
+ const std::function<void(int64, int64)>& fn) {
+ const int num_shards_used =
+ NumShardsUsedByTransformRangeConcurrently(block_size, total);
+ if (num_shards_used == 1) {
+ fn(0, total);
+ return;
+ }
+
+ // Adapted from Eigen's parallelFor implementation.
+ BlockingCounter counter(num_shards_used);
+ std::function<void(int64, int64)> handle_range =
+ [=, &handle_range, &counter, &fn](int64 first, int64 last) {
+ while (last - first > block_size) {
+ // Find something near the midpoint which is a multiple of block size.
+ const int64 mid = first + ((last - first) / 2 + block_size - 1) /
+ block_size * block_size;
+ Schedule([=, &handle_range]() { handle_range(mid, last); });
+ last = mid;
+ }
+ // Single block or less, execute directly.
+ fn(first, last);
+ counter.DecrementCount(); // The shard is done.
+ };
+ if (num_shards_used <= NumThreads()) {
+ // Avoid a thread hop by running the root of the tree and one block on the
+ // main thread.
+ handle_range(0, total);
+ } else {
+ // Execute the root in the thread pool to avoid running work on more than
+ // numThreads() threads.
+ Schedule([=, &handle_range]() { handle_range(0, total); });
+ }
+ counter.Wait();
+}
+
void ThreadPool::ParallelFor(int64 total, int64 cost_per_unit,
std::function<void(int64, int64)> fn) {
impl_->ParallelFor(total, cost_per_unit, std::move(fn));
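
A note on the scheduling above: handle_range repeatedly splits its range at a midpoint rounded up to a block boundary and schedules the upper half, so work fans out as a binary tree of Schedule calls rather than one call per shard. A standalone serial trace of the splitting (illustrative C++ sketch, not part of the change):

  #include <cstdint>
  #include <cstdio>

  // Serial re-enactment of handle_range; the real code schedules the
  // [mid, last) half on the pool instead of recursing inline.
  void TraceRange(int64_t block_size, int64_t first, int64_t last) {
    while (last - first > block_size) {
      const int64_t mid = first + ((last - first) / 2 + block_size - 1) /
                                      block_size * block_size;
      TraceRange(block_size, mid, last);
      last = mid;
    }
    std::printf("shard [%lld, %lld)\n", static_cast<long long>(first),
                static_cast<long long>(last));
  }

  int main() {
    TraceRange(3, 0, 10);  // shards [9,10), [6,9), [3,6), [0,3)
    return 0;
  }

Every shard except the one ending at total covers exactly block_size elements, which is what makes the ceiling division in NumShardsUsedByTransformRangeConcurrently exact (here ceil(10/3) = 4).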
diff --git a/tensorflow/core/lib/core/threadpool.h b/tensorflow/core/lib/core/threadpool.h
index 74df7c84a4..e14ad7ac64 100644
--- a/tensorflow/core/lib/core/threadpool.h
+++ b/tensorflow/core/lib/core/threadpool.h
@@ -59,6 +59,20 @@ class ThreadPool {
// Schedules fn() for execution in the pool of threads.
void Schedule(std::function<void()> fn);
+ // Requires 0 < block_size <= total.
+ // Shards the work into k = NumShardsUsedByTransformRangeConcurrently(
+ // block_size, total) pieces and calls fn(i*block_size, (i+1)*block_size)
+ // for the i-th shard (i >= 0); for the final shard, where (i+1)*block_size
+ // would exceed total, fn(i*block_size, total) is called instead.
+ // Note that when there aren't enough threads in the pool to achieve full
+ // parallelism, function calls will be automatically queued.
+ void TransformRangeConcurrently(const int64 block_size, const int64 total,
+ const std::function<void(int64, int64)>& fn);
+
+ // Returns the number of shards into which TransformRangeConcurrently
+ // splits the work for these parameters.
+ int NumShardsUsedByTransformRangeConcurrently(const int64 block_size,
+ const int64 total);
+
// ParallelFor shards the "total" units of work assuming each unit of work
// having roughly "cost_per_unit" cost, in cycles. Each unit of work is
// indexed 0, 1, ..., total - 1. Each shard contains 1 or more units of work
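
The shard count declared above is plain ceiling division behind guard clauses that collapse degenerate inputs to a single shard. A quick standalone check of that arithmetic against the cases the tests below exercise (illustrative sketch; num_threads stands in for the pool's NumThreads()):

  #include <cassert>
  #include <cstdint>

  // Mirrors the guard clauses and ceiling division of
  // NumShardsUsedByTransformRangeConcurrently.
  int64_t NumShards(int64_t block_size, int64_t total, int num_threads) {
    if (block_size <= 0 || total <= 1 || total <= block_size ||
        num_threads == 1) {
      return 1;
    }
    return (total + block_size - 1) / block_size;  // ceil(total / block_size)
  }

  int main() {
    assert(NumShards(3, 3, 16) == 1);  // total <= block_size -> one shard
    assert(NumShards(3, 4, 16) == 2);  // ceil(4/3)
    assert(NumShards(3, 7, 16) == 3);  // ceil(7/3)
    assert(NumShards(1, 7, 16) == 7);
    assert(NumShards(0, 7, 16) == 1);  // invalid block size -> serial
    assert(NumShards(3, 7, 1) == 1);   // single-threaded pool -> serial
    return 0;
  }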
diff --git a/tensorflow/core/lib/core/threadpool_test.cc b/tensorflow/core/lib/core/threadpool_test.cc
index 320f3ebb83..db996b783f 100644
--- a/tensorflow/core/lib/core/threadpool_test.cc
+++ b/tensorflow/core/lib/core/threadpool_test.cc
@@ -61,6 +61,67 @@ TEST(ThreadPool, DoWork) {
}
}
+void RunSharding(int64 block_size, int64 total, ThreadPool* threads) {
+ mutex mu;
+ int64 num_shards = 0;
+ int64 num_done_work = 0;
+ std::vector<bool> work(total, false);
+ threads->TransformRangeConcurrently(
+ block_size, total,
+ [=, &mu, &num_shards, &num_done_work, &work](int64 start, int64 end) {
+ VLOG(1) << "Shard [" << start << "," << end << ")";
+ EXPECT_GE(start, 0);
+ EXPECT_LE(end, total);
+ mutex_lock l(mu);
+ ++num_shards;
+ for (; start < end; ++start) {
+ EXPECT_FALSE(work[start]); // No duplicate
+ ++num_done_work;
+ work[start] = true;
+ }
+ });
+ LOG(INFO) << "block_size " << block_size << " total " << total;
+ const int64 num_workers = (total + block_size - 1) / block_size;
+ EXPECT_EQ(num_done_work, total);
+ if (num_workers < threads->NumThreads()) {
+ // If the intention is to limit the parallelism explicitly, we'd
+ // better honor it. Ideally, even if per_thread_max_parallelism >
+ // num_workers, the Shard() implementation should not over-shard.
+ // Unfortunately, ThreadPoolDevice::parallelFor tends to over-shard.
+ EXPECT_LE(num_shards, 1 + num_workers);
+ }
+}
+
+// Adapted from work_sharder_test.cc
+TEST(SparseUtilsTest, TransformRangeConcurrently) {
+ ThreadPool threads(Env::Default(), "test", 16);
+ for (auto block_size : {1, 7, 10, 64, 100, 256, 1000, 9999}) {
+ for (auto diff : {0, 1, 11, 102, 1003, 10005, 1000007}) {
+ const int64 total = block_size + diff;
+ RunSharding(block_size, total, &threads);
+ }
+ }
+}
+
+TEST(SparseUtilsTest, NumShardsUsedByTransformRangeConcurrently) {
+ ThreadPool threads(Env::Default(), "test", 16);
+ EXPECT_EQ(1, threads.NumShardsUsedByTransformRangeConcurrently(
+ 3 /* block_size */, 3 /* total */));
+ EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently(
+ 3 /* block_size */, 4 /* total */));
+ EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently(
+ 3 /* block_size */, 5 /* total */));
+ EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently(
+ 3 /* block_size */, 6 /* total */));
+ EXPECT_EQ(3, threads.NumShardsUsedByTransformRangeConcurrently(
+ 3 /* block_size */, 7 /* total */));
+ EXPECT_EQ(7, threads.NumShardsUsedByTransformRangeConcurrently(
+ 1 /* block_size */, 7 /* total */));
+ EXPECT_EQ(1, threads.NumShardsUsedByTransformRangeConcurrently(
+ 0 /* block_size */, 7 /* total */));
+}
+
TEST(ThreadPool, ParallelFor) {
Context outer_context(ContextKind::kThread);
// Make ParallelFor use as many threads as possible.