aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/work_sharder.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/util/work_sharder.h')
-rw-r--r--tensorflow/core/util/work_sharder.h33
1 files changed, 33 insertions, 0 deletions
diff --git a/tensorflow/core/util/work_sharder.h b/tensorflow/core/util/work_sharder.h
new file mode 100644
index 0000000000..1ea2cf4397
--- /dev/null
+++ b/tensorflow/core/util/work_sharder.h
@@ -0,0 +1,33 @@
+#ifndef TENSORFLOW_UTIL_WORK_SHARDER_H_
+#define TENSORFLOW_UTIL_WORK_SHARDER_H_
+
+#include <functional>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+
+namespace tensorflow {
+
+// Shards the "total" units of work, assuming each unit of work has
+// roughly the same "cost_per_unit". Each unit of work is indexed 0,
+// 1, ..., total - 1. Each shard contains 1 or more units of work and
+// the total cost of each shard is roughly the same. The total number
+// of shards is no more than num_workers. The calling thread and the
+// "workers" are used to compute each shard (by calling work(start,
+// limit)). A common configuration is that "workers" is a thread pool
+// with "num_workers" threads.
+//
+// "work" should be a callable taking (int64, int64) arguments.
+// work(start, limit) computes the work units in the half-open range
+// [start, limit), i.e., [start, limit) is a shard.
+//
+// REQUIRES: num_workers >= 0
+// REQUIRES: workers != nullptr
+// REQUIRES: total >= 0
+// REQUIRES: cost_per_unit >= 0
+void Shard(int num_workers, thread::ThreadPool* workers, int64 total,
+           int64 cost_per_unit, std::function<void(int64, int64)> work);
+
+} // end namespace tensorflow
+
+#endif  // TENSORFLOW_UTIL_WORK_SHARDER_H_