diff options
author | Christopher Olston <olston@google.com> | 2017-11-29 13:58:55 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-29 14:02:28 -0800 |
commit | 1d0b07351d901334b33565595d4c23607f11cc27 (patch) | |
tree | a662da76cd373d7cee2e95c8b1ce49e7e1134662 /tensorflow/contrib/batching | |
parent | 48347ee4105d78d8f36ba8645953b75cb5280c4c (diff) |
Add a way to query a batch scheduler to determine the max task size.
A layer on top of the batcher could use this interface to pre-split large tasks that exceed the max batch size.
PiperOrigin-RevId: 177359263
Diffstat (limited to 'tensorflow/contrib/batching')
7 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h index 6ed177e001..9e32bee505 100644 --- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h +++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h @@ -208,6 +208,8 @@ class ASBSQueue : public BatchScheduler<TaskType> { // place any more tasks in this batch. void ReleaseBatch(const ASBSBatch<TaskType>* batch); + size_t max_task_size() const override { return options_.max_batch_size; } + private: std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler_; const QueueOptions options_; diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc index a07cd6d834..e2aac54eeb 100644 --- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc +++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc @@ -186,6 +186,7 @@ TEST(AdaptiveSharedBatchSchedulerTest, ObeysQueueOptions) { queue_options.max_enqueued_batches = 2; TF_ASSERT_OK( scheduler->AddQueue(queue_options, queue_0_callback, &queue_0)); + EXPECT_EQ(10, queue_0->max_task_size()); queue_options.max_batch_size = 0; // Queue must have max_batch_size > 0. EXPECT_FALSE( diff --git a/tensorflow/contrib/batching/basic_batch_scheduler.h b/tensorflow/contrib/batching/basic_batch_scheduler.h index 9d3805fbaf..91065db249 100644 --- a/tensorflow/contrib/batching/basic_batch_scheduler.h +++ b/tensorflow/contrib/batching/basic_batch_scheduler.h @@ -192,6 +192,10 @@ class BasicBatchScheduler : public BatchScheduler<TaskType> { size_t NumEnqueuedTasks() const override; size_t SchedulingCapacity() const override; + size_t max_task_size() const override { + return shared_scheduler_queue_->max_task_size(); + } + private: explicit BasicBatchScheduler( std::unique_ptr<BatchScheduler<TaskType>> shared_scheduler_queue); diff --git a/tensorflow/contrib/batching/basic_batch_scheduler_test.cc b/tensorflow/contrib/batching/basic_batch_scheduler_test.cc index e020301795..187823151c 100644 --- a/tensorflow/contrib/batching/basic_batch_scheduler_test.cc +++ b/tensorflow/contrib/batching/basic_batch_scheduler_test.cc @@ -73,6 +73,7 @@ TEST(BasicBatchSchedulerTest, Basic) { std::unique_ptr<BasicBatchScheduler<FakeTask>> scheduler; TF_ASSERT_OK( BasicBatchScheduler<FakeTask>::Create(options, callback, &scheduler)); + EXPECT_EQ(10, scheduler->max_task_size()); EXPECT_EQ(0, scheduler->NumEnqueuedTasks()); EXPECT_EQ(3 * 10, scheduler->SchedulingCapacity()); TF_ASSERT_OK(ScheduleTask(3, scheduler.get())); diff --git a/tensorflow/contrib/batching/batch_scheduler.h b/tensorflow/contrib/batching/batch_scheduler.h index a5072f439a..e18cf6c350 100644 --- a/tensorflow/contrib/batching/batch_scheduler.h +++ b/tensorflow/contrib/batching/batch_scheduler.h @@ -178,6 +178,10 @@ class BatchScheduler { // This method is useful for monitoring, or for guaranteeing a future slot in // the schedule (but being mindful about the caveats listed above). virtual size_t SchedulingCapacity() const = 0; + + // Returns the maximum allowed size of tasks submitted to the scheduler. (This + // is typically equal to a configured maximum batch size.) + virtual size_t max_task_size() const = 0; }; ////////// diff --git a/tensorflow/contrib/batching/shared_batch_scheduler.h b/tensorflow/contrib/batching/shared_batch_scheduler.h index 41a3f99137..1d2158062e 100644 --- a/tensorflow/contrib/batching/shared_batch_scheduler.h +++ b/tensorflow/contrib/batching/shared_batch_scheduler.h @@ -248,6 +248,9 @@ class Queue { // BatchScheduler::SchedulingCapacity(). size_t SchedulingCapacity() const; + // Returns the maximum allowed size of tasks submitted to the queue. + size_t max_task_size() const { return options_.max_batch_size; } + // Called by a thread that is ready to process a batch, to request one from // this queue. Either returns a batch that is ready to be processed, or // nullptr if the queue declines to schedule a batch at this time. If it @@ -338,6 +341,8 @@ class QueueHandle : public BatchScheduler<TaskType> { size_t NumEnqueuedTasks() const override; size_t SchedulingCapacity() const override; + size_t max_task_size() const override { return queue_->max_task_size(); } + private: // The scheduler that owns 'queue_'. std::shared_ptr<SharedBatchScheduler<TaskType>> scheduler_; diff --git a/tensorflow/contrib/batching/shared_batch_scheduler_test.cc b/tensorflow/contrib/batching/shared_batch_scheduler_test.cc index 3e924ae5f1..3ac79a8fdc 100644 --- a/tensorflow/contrib/batching/shared_batch_scheduler_test.cc +++ b/tensorflow/contrib/batching/shared_batch_scheduler_test.cc @@ -429,6 +429,7 @@ TEST(SharedBatchSchedulerTest, ConstMethods) { queue_options.max_enqueued_batches = max_enqueued_batches; std::unique_ptr<BatchScheduler<FakeTask>> queue; TF_ASSERT_OK(scheduler->AddQueue(queue_options, callback, &queue)); + EXPECT_EQ(2, queue->max_task_size()); EXPECT_EQ(0, queue->NumEnqueuedTasks()); EXPECT_EQ(max_enqueued_batches * 2, queue->SchedulingCapacity()); |