about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/batching
diff options
context:
space:
mode:
author    Christopher Olston <olston@google.com>  2017-11-29 13:58:55 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-11-29 14:02:28 -0800
commit    1d0b07351d901334b33565595d4c23607f11cc27 (patch)
tree      a662da76cd373d7cee2e95c8b1ce49e7e1134662 /tensorflow/contrib/batching
parent    48347ee4105d78d8f36ba8645953b75cb5280c4c (diff)
Add a way to query a batch scheduler to determine the max task size.

A layer on top of the batcher could use this interface to pre-split large
tasks that exceed the max batch size.

PiperOrigin-RevId: 177359263
Diffstat (limited to 'tensorflow/contrib/batching')
-rw-r--r--  tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h       | 2
-rw-r--r--  tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc | 1
-rw-r--r--  tensorflow/contrib/batching/basic_batch_scheduler.h                 | 4
-rw-r--r--  tensorflow/contrib/batching/basic_batch_scheduler_test.cc           | 1
-rw-r--r--  tensorflow/contrib/batching/batch_scheduler.h                       | 4
-rw-r--r--  tensorflow/contrib/batching/shared_batch_scheduler.h                | 5
-rw-r--r--  tensorflow/contrib/batching/shared_batch_scheduler_test.cc          | 1
7 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h
index 6ed177e001..9e32bee505 100644
--- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h
+++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h
@@ -208,6 +208,8 @@ class ASBSQueue : public BatchScheduler<TaskType> {
// place any more tasks in this batch.
void ReleaseBatch(const ASBSBatch<TaskType>* batch);
+ size_t max_task_size() const override { return options_.max_batch_size; }
+
private:
std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler_;
const QueueOptions options_;
diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc
index a07cd6d834..e2aac54eeb 100644
--- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc
+++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc
@@ -186,6 +186,7 @@ TEST(AdaptiveSharedBatchSchedulerTest, ObeysQueueOptions) {
queue_options.max_enqueued_batches = 2;
TF_ASSERT_OK(
scheduler->AddQueue(queue_options, queue_0_callback, &queue_0));
+ EXPECT_EQ(10, queue_0->max_task_size());
queue_options.max_batch_size = 0;
// Queue must have max_batch_size > 0.
EXPECT_FALSE(
diff --git a/tensorflow/contrib/batching/basic_batch_scheduler.h b/tensorflow/contrib/batching/basic_batch_scheduler.h
index 9d3805fbaf..91065db249 100644
--- a/tensorflow/contrib/batching/basic_batch_scheduler.h
+++ b/tensorflow/contrib/batching/basic_batch_scheduler.h
@@ -192,6 +192,10 @@ class BasicBatchScheduler : public BatchScheduler<TaskType> {
size_t NumEnqueuedTasks() const override;
size_t SchedulingCapacity() const override;
+ size_t max_task_size() const override {
+ return shared_scheduler_queue_->max_task_size();
+ }
+
private:
explicit BasicBatchScheduler(
std::unique_ptr<BatchScheduler<TaskType>> shared_scheduler_queue);
diff --git a/tensorflow/contrib/batching/basic_batch_scheduler_test.cc b/tensorflow/contrib/batching/basic_batch_scheduler_test.cc
index e020301795..187823151c 100644
--- a/tensorflow/contrib/batching/basic_batch_scheduler_test.cc
+++ b/tensorflow/contrib/batching/basic_batch_scheduler_test.cc
@@ -73,6 +73,7 @@ TEST(BasicBatchSchedulerTest, Basic) {
std::unique_ptr<BasicBatchScheduler<FakeTask>> scheduler;
TF_ASSERT_OK(
BasicBatchScheduler<FakeTask>::Create(options, callback, &scheduler));
+ EXPECT_EQ(10, scheduler->max_task_size());
EXPECT_EQ(0, scheduler->NumEnqueuedTasks());
EXPECT_EQ(3 * 10, scheduler->SchedulingCapacity());
TF_ASSERT_OK(ScheduleTask(3, scheduler.get()));
diff --git a/tensorflow/contrib/batching/batch_scheduler.h b/tensorflow/contrib/batching/batch_scheduler.h
index a5072f439a..e18cf6c350 100644
--- a/tensorflow/contrib/batching/batch_scheduler.h
+++ b/tensorflow/contrib/batching/batch_scheduler.h
@@ -178,6 +178,10 @@ class BatchScheduler {
// This method is useful for monitoring, or for guaranteeing a future slot in
// the schedule (but being mindful about the caveats listed above).
virtual size_t SchedulingCapacity() const = 0;
+
+ // Returns the maximum allowed size of tasks submitted to the scheduler. (This
+ // is typically equal to a configured maximum batch size.)
+ virtual size_t max_task_size() const = 0;
};
//////////
diff --git a/tensorflow/contrib/batching/shared_batch_scheduler.h b/tensorflow/contrib/batching/shared_batch_scheduler.h
index 41a3f99137..1d2158062e 100644
--- a/tensorflow/contrib/batching/shared_batch_scheduler.h
+++ b/tensorflow/contrib/batching/shared_batch_scheduler.h
@@ -248,6 +248,9 @@ class Queue {
// BatchScheduler::SchedulingCapacity().
size_t SchedulingCapacity() const;
+ // Returns the maximum allowed size of tasks submitted to the queue.
+ size_t max_task_size() const { return options_.max_batch_size; }
+
// Called by a thread that is ready to process a batch, to request one from
// this queue. Either returns a batch that is ready to be processed, or
// nullptr if the queue declines to schedule a batch at this time. If it
@@ -338,6 +341,8 @@ class QueueHandle : public BatchScheduler<TaskType> {
size_t NumEnqueuedTasks() const override;
size_t SchedulingCapacity() const override;
+ size_t max_task_size() const override { return queue_->max_task_size(); }
+
private:
// The scheduler that owns 'queue_'.
std::shared_ptr<SharedBatchScheduler<TaskType>> scheduler_;
diff --git a/tensorflow/contrib/batching/shared_batch_scheduler_test.cc b/tensorflow/contrib/batching/shared_batch_scheduler_test.cc
index 3e924ae5f1..3ac79a8fdc 100644
--- a/tensorflow/contrib/batching/shared_batch_scheduler_test.cc
+++ b/tensorflow/contrib/batching/shared_batch_scheduler_test.cc
@@ -429,6 +429,7 @@ TEST(SharedBatchSchedulerTest, ConstMethods) {
queue_options.max_enqueued_batches = max_enqueued_batches;
std::unique_ptr<BatchScheduler<FakeTask>> queue;
TF_ASSERT_OK(scheduler->AddQueue(queue_options, callback, &queue));
+ EXPECT_EQ(2, queue->max_task_size());
EXPECT_EQ(0, queue->NumEnqueuedTasks());
EXPECT_EQ(max_enqueued_batches * 2, queue->SchedulingCapacity());