author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-10-20 13:19:42 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-10-20 13:23:41 -0700
commit     d2d9a6c7cc3b4f8c068054082a0fa2f2b95bb3d6 (patch)
tree       f4153a1924b91cc025dc7640689a2aaee8290a4c /tensorflow/contrib/batching
parent     e65fbbc9dc608d97977b17e05250b015d65aa027 (diff)
Add AdaptiveSharedBatchScheduler, which processes batches at a variable rate that can be adjusted based on external feedback. Given reasonable feedback, this scheduler should deliver better latency than the SharedBatchScheduler.
PiperOrigin-RevId: 172924803
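
The adjustment rule this commit implements can be sketched as a small self-contained function (illustrative only; the committed logic lives in ProcessOneBatch() in the new header below, with min_period/max_period standing in for the corresponding Options fields and kFeedbackMultiplier for the constant K):

#include <algorithm>
#include <cstdint>

// EWMA-smooth the raw feedback, then nudge the period multiplicatively and
// clamp it to the configured bounds.
double UpdatePeriod(double period, double* ewma_feedback, double feedback,
                    int64_t n, double min_period, double max_period) {
  constexpr double kFeedbackMultiplier = 0.001;
  *ewma_feedback = ((n - 1) * *ewma_feedback + feedback) / n;
  period *= (1 + kFeedbackMultiplier * *ewma_feedback);
  return std::min(max_period, std::max(min_period, period));
}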
Diffstat (limited to 'tensorflow/contrib/batching')
-rw-r--r--  tensorflow/contrib/batching/BUILD                                    |  22
-rw-r--r--  tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h       | 463
-rw-r--r--  tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc | 438
-rw-r--r--  tensorflow/contrib/batching/batch_scheduler.h                        |   2
4 files changed, 924 insertions(+), 1 deletion(-)
diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD
index 1555a3427f..ae3f48f1b2 100644
--- a/tensorflow/contrib/batching/BUILD
+++ b/tensorflow/contrib/batching/BUILD
@@ -70,6 +70,28 @@ tf_cc_test(
)
cc_library(
+ name = "adaptive_shared_batch_scheduler",
+ hdrs = ["adaptive_shared_batch_scheduler.h"],
+ deps = [
+ ":batch_scheduler",
+ "//tensorflow/contrib/batching/util:periodic_function_dynamic",
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_cc_test(
+ name = "adaptive_shared_batch_scheduler_test",
+ srcs = ["adaptive_shared_batch_scheduler_test.cc"],
+ deps = [
+ ":adaptive_shared_batch_scheduler",
+ "//tensorflow/contrib/batching/test_util:fake_clock_env",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
name = "basic_batch_scheduler",
hdrs = ["basic_batch_scheduler.h"],
deps = [
diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h
new file mode 100644
index 0000000000..ac32f09639
--- /dev/null
+++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h
@@ -0,0 +1,463 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
+
+#include <functional>
+#include <memory>
+#include <queue>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/batching/batch_scheduler.h"
+#include "tensorflow/contrib/batching/util/periodic_function.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace serving {
+namespace internal {
+template <typename TaskType>
+class ASBSBatch;
+
+template <typename TaskType>
+class ASBSQueue;
+} // namespace internal
+
+// Shared batch scheduler designed to minimize latency. The scheduler keeps
+// track of a number of queues (one per model or model version) which are
+// continuously enqueuing requests. The scheduler groups the requests into
+// batches which it periodically sends off for processing (see
+// shared_batch_scheduler.h for more details). The AdaptiveSharedBatchScheduler
+// prioritizes batches by age (i.e. the age of each batch's oldest request),
+// irrespective of queue, and processes the oldest batch at an adjustable rate,
+// regardless of batch size. The user can provide feedback to help set this rate
+// to achieve some goal (e.g. minimize overall latency, limit CPU usage, etc.).
+//
+// The rate (or rather, the corresponding period) is adjusted each time a batch
+// is processed, using an exponentially weighted moving average to smooth
+// potentially noisy feedback:
+//   ewma_feedback = ((N - 1) * ewma_feedback + feedback()) / N
+//   period *= (1 + K * ewma_feedback)
+//
+// Some potential use cases:
+// Hardware Accelerators (GPUs & TPUs) - If some phase of batch processing
+// involves serial processing by a device, from a latency perspective it is
+// desirable to keep the device evenly loaded, avoiding the need to wait for
+// the device to process prior batches.
+// feedback = num_pending_on_device() - desired_pending.
+// CPU utilization - If the batch processing is cpu dominated, you can reap
+// latency gains when underutilized by increasing the processing rate, but
+// back the rate off when the load increases to avoid overload.
+// feedback = cpu_rate() - desired_cpu_rate.
+
+template <typename TaskType>
+class AdaptiveSharedBatchScheduler
+ : public std::enable_shared_from_this<
+ AdaptiveSharedBatchScheduler<TaskType>> {
+ public:
+ struct Options {
+ // The name to use for the pool of batch threads.
+ string thread_pool_name = {"batch_threads"};
+ // Number of batch processing threads; equivalently the maximum number of
+ // concurrently running batches.
+ int64 num_batch_threads = port::NumSchedulableCPUs();
+ // The environment to use (typically only overridden by test code).
+ Env* env = Env::Default();
+    // Initial batch scheduling period in microseconds. Will be adjusted when
+    // scheduling_period_feedback returns non-zero values.
+ double initial_scheduling_period_micros = 500;
+ // Minimum batch scheduling period in microseconds. Recommend setting this
+ // value greater than 0, otherwise it may take a while to recover from a
+ // sustained time of negative scheduling_period_feedback (which may occur
+ // under low load).
+ double min_scheduling_period_micros = 100;
+ // Maximum batch scheduling period in microseconds.
+ double max_scheduling_period_micros = 10000;
+ // Feedback function used to modify the scheduling period each time a batch
+ // is scheduled. Should return values roughly O(1), with positive values
+ // resulting in an increased period.
+ std::function<double()> scheduling_period_feedback = [] { return 0.; };
+ // To handle potentially noisy scheduling_period_feedback, the period is
+ // adjusted using an exponentially weighted moving average over the previous
+ // feedback_smoothing_batches batches. Must be greater than 0.
+ int64 feedback_smoothing_batches = 10;
+ };
+
+ // Ownership is shared between the caller of Create() and any queues created
+ // via AddQueue().
+ static Status Create(
+ const Options& options,
+ std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler);
+
+ struct QueueOptions {
+ // Maximum size of each batch.
+ int max_batch_size = 1000;
+ // Maximum number of enqueued (i.e. non-scheduled) batches.
+ int max_enqueued_batches = 10;
+ };
+
+ using BatchProcessor = std::function<void(std::unique_ptr<Batch<TaskType>>)>;
+
+ // Adds queue (and its callback) to be managed by this scheduler.
+ Status AddQueue(const QueueOptions& options,
+ BatchProcessor process_batch_callback,
+ std::unique_ptr<BatchScheduler<TaskType>>* queue);
+
+ private:
+  // Gives ASBSQueue access to AddBatch, RemoveQueue, and GetEnv.
+ friend class internal::ASBSQueue<TaskType>;
+
+ explicit AdaptiveSharedBatchScheduler(const Options& options);
+
+ // Batch scheduling function which runs every scheduling_period_ microseconds.
+ void ProcessOneBatch();
+
+ // Notifies scheduler of non-empty batch which is eligible for processing.
+ void AddBatch(internal::ASBSBatch<TaskType>*);
+
+ // Removes queue from scheduler.
+ void RemoveQueue(const internal::ASBSQueue<TaskType>* queue);
+
+ Env* GetEnv() const { return options_.env; }
+
+ const Options options_;
+
+ struct BatchCompare {
+ bool operator()(const internal::ASBSBatch<TaskType>* a,
+ const internal::ASBSBatch<TaskType>* b);
+ };
+
+ // Collection of batches added by AddBatch, ordered by age. Owned by scheduler
+ // until they are released for processing.
+ std::priority_queue<const internal::ASBSBatch<TaskType>*,
+ std::vector<internal::ASBSBatch<TaskType>*>, BatchCompare>
+ batches_ GUARDED_BY(mu_);
+
+ // Unowned queues and callbacks added by AddQueue.
+ std::unordered_map<const internal::ASBSQueue<TaskType>*, BatchProcessor>
+ queues_and_callbacks_ GUARDED_BY(mu_);
+
+ mutex mu_;
+
+  // Responsible for running ProcessOneBatch. A PeriodicFunction is used so
+  // that deletion is detected between calls and the thread can be shut down.
+ std::unique_ptr<PeriodicFunction> scheduling_thread_;
+
+ // Responsible for running the batch processing callbacks.
+ std::unique_ptr<thread::ThreadPool> batch_thread_pool_;
+
+ // Time interval in microseconds between successive ProcessOneBatch calls.
+ double scheduling_period_;
+
+ // Exponentially weighted moving average of
+ // options_.scheduling_period_feedback() evaluated in each ProcessOneBatch
+ // call.
+ double ewma_feedback_ = 0;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(AdaptiveSharedBatchScheduler);
+};
+
+//////////////////////////////////////////////////////////
+// Implementation details follow. API users need not read.
+
+namespace internal {
+// Consolidates tasks into batches, passing them off to the
+// AdaptiveSharedBatchScheduler for processing.
+template <typename TaskType>
+class ASBSQueue : public BatchScheduler<TaskType> {
+ public:
+ using QueueOptions =
+ typename AdaptiveSharedBatchScheduler<TaskType>::QueueOptions;
+
+ ASBSQueue(std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
+ const QueueOptions& options);
+
+ ~ASBSQueue() override;
+
+  // Adds task to the current batch. Fails if the task size is larger than the
+  // maximum batch size, or if the current batch is full and this queue's
+  // number of outstanding batches is at its maximum.
+ Status Schedule(std::unique_ptr<TaskType>* task) override;
+
+ // Number of tasks waiting to be scheduled.
+ size_t NumEnqueuedTasks() const override;
+
+ // Number of size 1 tasks which could currently be scheduled without failing.
+ size_t SchedulingCapacity() const override;
+
+ // Notifies queue that a batch is about to be scheduled; the queue should not
+ // place any more tasks in this batch.
+ void ReleaseBatch(const ASBSBatch<TaskType>* batch);
+
+ private:
+ std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler_;
+ const QueueOptions options_;
+ // Owned by scheduler_.
+ ASBSBatch<TaskType>* current_batch_ GUARDED_BY(mu_) = nullptr;
+ int64 num_enqueued_batches_ GUARDED_BY(mu_) = 0;
+ int64 num_enqueued_tasks_ GUARDED_BY(mu_) = 0;
+ mutable mutex mu_;
+ TF_DISALLOW_COPY_AND_ASSIGN(ASBSQueue);
+};
+
+// Batch which remembers when and by whom it was created.
+template <typename TaskType>
+class ASBSBatch : public Batch<TaskType> {
+ public:
+ ASBSBatch(ASBSQueue<TaskType>* queue, int64 creation_time_micros)
+ : queue_(queue), creation_time_micros_(creation_time_micros) {}
+
+ ~ASBSBatch() override {}
+
+ ASBSQueue<TaskType>* queue() const { return queue_; }
+
+ int64 creation_time_micros() const { return creation_time_micros_; }
+
+ private:
+ ASBSQueue<TaskType>* queue_;
+ const int64 creation_time_micros_;
+ TF_DISALLOW_COPY_AND_ASSIGN(ASBSBatch);
+};
+} // namespace internal
+
+// ---------------- AdaptiveSharedBatchScheduler ----------------
+
+template <typename TaskType>
+Status AdaptiveSharedBatchScheduler<TaskType>::Create(
+ const Options& options,
+ std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler) {
+ if (options.num_batch_threads < 1) {
+ return errors::InvalidArgument("num_batch_threads must be positive; was ",
+ options.num_batch_threads);
+ }
+ if (options.min_scheduling_period_micros < 0) {
+ return errors::InvalidArgument(
+ "min_scheduling_period_micros must be >= 0; was ",
+ options.min_scheduling_period_micros);
+ }
+ if (options.min_scheduling_period_micros >
+ options.initial_scheduling_period_micros) {
+ return errors::InvalidArgument(
+ "initial_scheduling_period_micros (",
+ options.initial_scheduling_period_micros,
+ ") must be >= min_scheduling_period_micros (",
+ options.min_scheduling_period_micros, ")");
+ }
+ if (options.initial_scheduling_period_micros >
+ options.max_scheduling_period_micros) {
+ return errors::InvalidArgument(
+ "initial_scheduling_period_micros (",
+ options.initial_scheduling_period_micros,
+ ") must be <= max_scheduling_period_micros (",
+ options.max_scheduling_period_micros, ")");
+ }
+ if (options.feedback_smoothing_batches < 1) {
+ return errors::InvalidArgument(
+ "feedback_smoothing_batches must be positive; was ",
+ options.feedback_smoothing_batches);
+ }
+ scheduler->reset(new AdaptiveSharedBatchScheduler<TaskType>(options));
+ return Status::OK();
+}
+
+template <typename TaskType>
+AdaptiveSharedBatchScheduler<TaskType>::AdaptiveSharedBatchScheduler(
+ const Options& options)
+ : options_(options),
+ scheduling_period_(options.initial_scheduling_period_micros) {
+ PeriodicFunction::Options opts;
+ opts.thread_name_prefix = "scheduling_thread";
+ opts.env = GetEnv();
+ scheduling_thread_.reset(
+ new PeriodicFunction([this] { ProcessOneBatch(); }, 0, opts));
+ batch_thread_pool_.reset(new thread::ThreadPool(
+ GetEnv(), options.thread_pool_name, options.num_batch_threads));
+}
+
+template <typename TaskType>
+Status AdaptiveSharedBatchScheduler<TaskType>::AddQueue(
+ const QueueOptions& options, BatchProcessor process_batch_callback,
+ std::unique_ptr<BatchScheduler<TaskType>>* queue) {
+ if (options.max_batch_size <= 0) {
+ return errors::InvalidArgument("max_batch_size must be positive; was ",
+ options.max_batch_size);
+ }
+ if (options.max_enqueued_batches <= 0) {
+ return errors::InvalidArgument(
+ "max_enqueued_batches must be positive; was ",
+ options.max_enqueued_batches);
+ }
+ internal::ASBSQueue<TaskType>* asbs_queue_raw;
+ queue->reset(asbs_queue_raw = new internal::ASBSQueue<TaskType>(
+ this->shared_from_this(), options));
+ mutex_lock l(mu_);
+ queues_and_callbacks_[asbs_queue_raw] = process_batch_callback;
+ return Status::OK();
+}
+
+template <typename TaskType>
+void AdaptiveSharedBatchScheduler<TaskType>::AddBatch(
+ internal::ASBSBatch<TaskType>* batch) {
+ mutex_lock l(mu_);
+ batches_.push(batch);
+}
+
+template <typename TaskType>
+void AdaptiveSharedBatchScheduler<TaskType>::RemoveQueue(
+ const internal::ASBSQueue<TaskType>* queue) {
+ mutex_lock l(mu_);
+ queues_and_callbacks_.erase(queue);
+}
+
+template <typename TaskType>
+void AdaptiveSharedBatchScheduler<TaskType>::ProcessOneBatch() {
+ static const double kFeedbackMultiplier = .001;
+ internal::ASBSBatch<TaskType>* batch = nullptr;
+ BatchProcessor callback;
+ const int64 start_time_micros = GetEnv()->NowMicros();
+ {
+ mutex_lock l(mu_);
+ if (!batches_.empty()) {
+ batch = batches_.top();
+ batches_.pop();
+ callback = queues_and_callbacks_[batch->queue()];
+ }
+ }
+ if (batch != nullptr) {
+ double feedback = options_.scheduling_period_feedback();
+ const int64 N = options_.feedback_smoothing_batches;
+ ewma_feedback_ = ((N - 1) * ewma_feedback_ + feedback) / N;
+ scheduling_period_ *= (1 + kFeedbackMultiplier * ewma_feedback_);
+ if (scheduling_period_ < options_.min_scheduling_period_micros) {
+ scheduling_period_ = options_.min_scheduling_period_micros;
+ } else if (scheduling_period_ > options_.max_scheduling_period_micros) {
+ scheduling_period_ = options_.max_scheduling_period_micros;
+ }
+ // Queue may destroy itself after ReleaseBatch is called.
+ batch->queue()->ReleaseBatch(batch);
+ batch_thread_pool_->Schedule([callback, batch] {
+ callback(std::unique_ptr<Batch<TaskType>>(batch));
+ });
+ }
+ const int64 sleep_time =
+ scheduling_period_ - (GetEnv()->NowMicros() - start_time_micros);
+ if (sleep_time > 0) {
+ GetEnv()->SleepForMicroseconds(sleep_time);
+ }
+}
+
+template <typename TaskType>
+bool AdaptiveSharedBatchScheduler<TaskType>::BatchCompare::operator()(
+ const internal::ASBSBatch<TaskType>* a,
+ const internal::ASBSBatch<TaskType>* b) {
+ return a->creation_time_micros() > b->creation_time_micros();
+}
+
+// ---------------- ASBSQueue ----------------
+
+namespace internal {
+template <typename TaskType>
+ASBSQueue<TaskType>::ASBSQueue(
+ std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
+ const QueueOptions& options)
+ : scheduler_(scheduler), options_(options) {}
+
+template <typename TaskType>
+ASBSQueue<TaskType>::~ASBSQueue() {
+ // Wait until last batch has been scheduled.
+ const int kSleepMicros = 1000;
+ for (;;) {
+ {
+ mutex_lock l(mu_);
+ if (num_enqueued_batches_ == 0) {
+ break;
+ }
+ }
+ scheduler_->GetEnv()->SleepForMicroseconds(kSleepMicros);
+ }
+ scheduler_->RemoveQueue(this);
+}
+
+template <typename TaskType>
+Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
+ bool added_new_batch = false;
+ size_t size = (*task)->size();
+ if (size > options_.max_batch_size) {
+ return errors::InvalidArgument("Task size ", size,
+ " is larger than maximum batch size ",
+ options_.max_batch_size);
+ }
+ {
+ mutex_lock l(mu_);
+ // Current batch is full, create another if allowed.
+ if (current_batch_ &&
+ current_batch_->size() + size > options_.max_batch_size) {
+ if (num_enqueued_batches_ >= options_.max_enqueued_batches) {
+ return errors::Unavailable("The batch scheduling queue is full");
+ }
+ current_batch_->Close();
+ current_batch_ = nullptr;
+ }
+ if (!current_batch_) {
+ added_new_batch = true;
+ num_enqueued_batches_++;
+ current_batch_ =
+ new ASBSBatch<TaskType>(this, scheduler_->GetEnv()->NowMicros());
+ }
+ current_batch_->AddTask(std::move(*task));
+ num_enqueued_tasks_++;
+ }
+ if (added_new_batch) scheduler_->AddBatch(current_batch_);
+ return Status::OK();
+}
+
+template <typename TaskType>
+void ASBSQueue<TaskType>::ReleaseBatch(const ASBSBatch<TaskType>* batch) {
+ mutex_lock l(mu_);
+ num_enqueued_batches_--;
+ num_enqueued_tasks_ -= batch->num_tasks();
+ if (batch == current_batch_) {
+ current_batch_->Close();
+ current_batch_ = nullptr;
+ }
+}
+
+template <typename TaskType>
+size_t ASBSQueue<TaskType>::NumEnqueuedTasks() const {
+ mutex_lock l(mu_);
+ return num_enqueued_tasks_;
+}
+
+template <typename TaskType>
+size_t ASBSQueue<TaskType>::SchedulingCapacity() const {
+ mutex_lock l(mu_);
+ const int current_batch_capacity =
+ current_batch_ ? options_.max_batch_size - current_batch_->size() : 0;
+ const int spare_batches =
+ options_.max_enqueued_batches - num_enqueued_batches_;
+ return spare_batches * options_.max_batch_size + current_batch_capacity;
+}
+} // namespace internal
+} // namespace serving
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
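
For orientation, a caller would wire the new scheduler up roughly as follows. This is a hypothetical sketch against the API above, not code from this change; MyTask and the empty batch callback are placeholders, and real code would handle errors rather than CHECK-fail:

#include <memory>

#include "tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h"

namespace {
// Any BatchTask subclass with a size() works; this one mirrors the test's
// FakeTask below.
struct MyTask : public tensorflow::serving::BatchTask {
  explicit MyTask(size_t size) : size_(size) {}
  size_t size() const override { return size_; }
  const size_t size_;
};
}  // namespace

void Example() {
  using Scheduler = tensorflow::serving::AdaptiveSharedBatchScheduler<MyTask>;
  Scheduler::Options options;  // Defaults are shown in Options above.
  std::shared_ptr<Scheduler> scheduler;
  TF_CHECK_OK(Scheduler::Create(options, &scheduler));

  std::unique_ptr<tensorflow::serving::BatchScheduler<MyTask>> queue;
  TF_CHECK_OK(scheduler->AddQueue(
      {}, [](std::unique_ptr<tensorflow::serving::Batch<MyTask>> batch) {
        // The batch arrives closed; process its tasks here.
      },
      &queue));

  std::unique_ptr<MyTask> task(new MyTask(1));
  TF_CHECK_OK(queue->Schedule(&task));  // Consumes 'task' on success.
}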
diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc
new file mode 100644
index 0000000000..a07cd6d834
--- /dev/null
+++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc
@@ -0,0 +1,438 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h"
+
+#include "tensorflow/contrib/batching/test_util/fake_clock_env.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace serving {
+namespace anonymous {
+
+class FakeTask : public BatchTask {
+ public:
+ explicit FakeTask(size_t size) : size_(size) {}
+
+ ~FakeTask() override = default;
+
+ size_t size() const override { return size_; }
+
+ private:
+ const size_t size_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FakeTask);
+};
+
+// Creates a FakeTask of size 'task_size', and calls 'scheduler->Schedule()' on
+// that task. Returns the resulting status.
+Status ScheduleTask(size_t task_size, BatchScheduler<FakeTask>* scheduler) {
+ std::unique_ptr<FakeTask> task(new FakeTask(task_size));
+ Status status = scheduler->Schedule(&task);
+ // Schedule() should have consumed 'task' iff it returned Status::OK.
+ CHECK_EQ(status.ok(), task == nullptr);
+ return status;
+}
+
+// Creates a thread that waits on 'start' and then advances the fake clock in
+// 'env' in a loop until 'stop' is notified. Useful for allowing objects that
+// use the clock to be destroyed.
+std::unique_ptr<Thread> CreateFakeClockAdvancerThread(
+ test_util::FakeClockEnv* env, Notification* start, Notification* stop) {
+ return std::unique_ptr<Thread>(Env::Default()->StartThread(
+ {}, "FakeClockAdvancerThread", [env, start, stop] {
+ start->WaitForNotification();
+ while (!stop->HasBeenNotified()) {
+ env->AdvanceByMicroseconds(10);
+ Env::Default()->SleepForMicroseconds(10);
+ }
+ }));
+}
+
+TEST(AdaptiveSharedBatchSchedulerTest, Basic) {
+ for (const bool delete_scheduler_early : {false, true}) {
+ for (const bool delete_queue_1_early : {false, true}) {
+ int queue_0_tasks = 0;
+ auto queue_0_callback =
+ [&queue_0_tasks](std::unique_ptr<Batch<FakeTask>> batch) {
+ ASSERT_TRUE(batch->IsClosed());
+ EXPECT_GT(batch->num_tasks(), 0);
+ for (int i = 0; i < batch->num_tasks(); i++) {
+ queue_0_tasks += batch->task(i).size();
+ }
+ };
+ int queue_1_tasks = 0;
+ auto queue_1_callback =
+ [&queue_1_tasks](std::unique_ptr<Batch<FakeTask>> batch) {
+ ASSERT_TRUE(batch->IsClosed());
+ EXPECT_GT(batch->num_tasks(), 0);
+ for (int i = 0; i < batch->num_tasks(); i++) {
+ queue_1_tasks += batch->task(i).size();
+ }
+ };
+ {
+ std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
+ TF_ASSERT_OK(
+ AdaptiveSharedBatchScheduler<FakeTask>::Create({}, &scheduler));
+
+ // Create two queues.
+ std::unique_ptr<BatchScheduler<FakeTask>> queue_0;
+ TF_ASSERT_OK(scheduler->AddQueue({}, queue_0_callback, &queue_0));
+ std::unique_ptr<BatchScheduler<FakeTask>> queue_1;
+ TF_ASSERT_OK(scheduler->AddQueue({}, queue_1_callback, &queue_1));
+
+ if (delete_scheduler_early) {
+ // Delete our copy of the scheduler. The queues should keep it alive
+ // under the covers.
+ scheduler = nullptr;
+ }
+ // Submit tasks to the two queues, and (optionally) remove the queues.
+ TF_ASSERT_OK(ScheduleTask(1, queue_0.get()));
+ TF_ASSERT_OK(ScheduleTask(2, queue_1.get()));
+ TF_ASSERT_OK(ScheduleTask(3, queue_0.get()));
+ TF_ASSERT_OK(ScheduleTask(4, queue_1.get()));
+ if (delete_queue_1_early) {
+ queue_1 = nullptr;
+ }
+ TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
+ }
+ EXPECT_EQ(queue_0_tasks, 9);
+ EXPECT_EQ(queue_1_tasks, 6);
+ }
+ }
+}
+
+TEST(AdaptiveSharedBatchSchedulerTest, BadOptions) {
+ using Scheduler = AdaptiveSharedBatchScheduler<FakeTask>;
+ std::shared_ptr<Scheduler> scheduler;
+ Scheduler::Options options;
+ options.num_batch_threads = 0;
+ EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
+ options = Scheduler::Options();
+ options.min_scheduling_period_micros = 50;
+ options.max_scheduling_period_micros = 100;
+ options.initial_scheduling_period_micros = 1;
+ EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
+ options = Scheduler::Options();
+ options.min_scheduling_period_micros = 50;
+ options.max_scheduling_period_micros = 100;
+ options.initial_scheduling_period_micros = 1000;
+ EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
+ options = Scheduler::Options();
+ options.min_scheduling_period_micros = 100;
+ options.max_scheduling_period_micros = 50;
+ options.initial_scheduling_period_micros = 75;
+ EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
+ options = Scheduler::Options();
+ options.feedback_smoothing_batches = 0;
+ EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
+}
+
+TEST(AdaptiveSharedBatchSchedulerTest, ObeysQueueOptions) {
+ test_util::FakeClockEnv env(Env::Default());
+ Notification start_teardown, stop_teardown;
+ std::unique_ptr<Thread> teardown_thread =
+ CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
+ {
+ AdaptiveSharedBatchScheduler<FakeTask>::Options options;
+ options.initial_scheduling_period_micros = 1000;
+ options.env = &env;
+ std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
+ TF_ASSERT_OK(
+ AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
+ std::unique_ptr<BatchScheduler<FakeTask>> queue_0;
+ std::unique_ptr<BatchScheduler<FakeTask>> queue_1;
+ int queue_0_tasks = 0;
+ int queue_1_tasks = 0;
+ auto queue_0_callback = [&queue_0_tasks,
+ &env](std::unique_ptr<Batch<FakeTask>> batch) {
+ ASSERT_TRUE(batch->IsClosed());
+ EXPECT_GT(batch->num_tasks(), 0);
+ for (int i = 0; i < batch->num_tasks(); i++) {
+ queue_0_tasks += batch->task(i).size();
+ }
+ env.SleepForMicroseconds(1);
+ };
+ auto queue_1_callback = [&queue_1_tasks,
+ &env](std::unique_ptr<Batch<FakeTask>> batch) {
+ ASSERT_TRUE(batch->IsClosed());
+ EXPECT_GT(batch->num_tasks(), 0);
+ for (int i = 0; i < batch->num_tasks(); i++) {
+ queue_1_tasks += batch->task(i).size();
+ }
+ env.SleepForMicroseconds(1);
+ };
+ AdaptiveSharedBatchScheduler<FakeTask>::QueueOptions queue_options;
+ queue_options.max_batch_size = 10;
+ queue_options.max_enqueued_batches = 0;
+    // Queue must have max_enqueued_batches > 0.
+ EXPECT_FALSE(
+ scheduler->AddQueue(queue_options, queue_0_callback, &queue_0).ok());
+ queue_options.max_enqueued_batches = 2;
+ TF_ASSERT_OK(
+ scheduler->AddQueue(queue_options, queue_0_callback, &queue_0));
+ queue_options.max_batch_size = 0;
+ // Queue must have max_batch_size > 0.
+ EXPECT_FALSE(
+ scheduler->AddQueue(queue_options, queue_1_callback, &queue_1).ok());
+ queue_options.max_batch_size = 2;
+ queue_options.max_enqueued_batches = 1;
+ TF_ASSERT_OK(
+ scheduler->AddQueue(queue_options, queue_1_callback, &queue_1));
+
+ // Wait for scheduling_thread to sleep.
+ env.BlockUntilThreadsAsleep(1);
+ // Task larger than max_batch_size shouldn't schedule.
+ EXPECT_FALSE(ScheduleTask(15, queue_0.get()).ok());
+ TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
+ TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
+ env.AdvanceByMicroseconds(1);
+
+ // Task larger than max_batch_size shouldn't schedule.
+ EXPECT_FALSE(ScheduleTask(3, queue_1.get()).ok());
+ TF_ASSERT_OK(ScheduleTask(1, queue_1.get()));
+ TF_ASSERT_OK(ScheduleTask(1, queue_1.get()));
+ env.AdvanceByMicroseconds(1);
+ // Exceeds max_enqueued_batches, shouldn't schedule.
+ EXPECT_FALSE(ScheduleTask(1, queue_1.get()).ok());
+
+ TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
+ // Exceeds max_enqueued_batches, shouldn't schedule.
+ EXPECT_FALSE(ScheduleTask(6, queue_0.get()).ok());
+ TF_ASSERT_OK(ScheduleTask(4, queue_0.get()));
+
+ // Batches should be processed in order from oldest to newest.
+ env.AdvanceByMicroseconds(1000);
+ env.BlockUntilThreadsAsleep(2);
+ EXPECT_EQ(queue_0_tasks, 10);
+ EXPECT_EQ(queue_1_tasks, 0);
+
+ env.AdvanceByMicroseconds(1000);
+ env.BlockUntilThreadsAsleep(2);
+ EXPECT_EQ(queue_0_tasks, 10);
+ EXPECT_EQ(queue_1_tasks, 2);
+
+ env.AdvanceByMicroseconds(1000);
+ env.BlockUntilThreadsAsleep(2);
+ EXPECT_EQ(queue_0_tasks, 19);
+ EXPECT_EQ(queue_1_tasks, 2);
+ start_teardown.Notify();
+ }
+ stop_teardown.Notify();
+}
+
+TEST(AdaptiveSharedBatchSchedulerTest, RateFeedback) {
+ test_util::FakeClockEnv env(Env::Default());
+ Notification start_teardown, stop_teardown;
+ std::unique_ptr<Thread> teardown_thread =
+ CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
+ {
+ double feedback = 0;
+ AdaptiveSharedBatchScheduler<FakeTask>::Options options;
+ options.initial_scheduling_period_micros = 1000;
+ options.min_scheduling_period_micros = 200;
+ options.max_scheduling_period_micros = 2000;
+ options.env = &env;
+ options.scheduling_period_feedback = [&feedback] { return feedback; };
+ options.feedback_smoothing_batches = 1;
+ std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
+ TF_ASSERT_OK(
+ AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
+ std::unique_ptr<BatchScheduler<FakeTask>> queue;
+ int scheduled_items = 0;
+ auto queue_callback = [&scheduled_items,
+ &env](std::unique_ptr<Batch<FakeTask>> batch) {
+ ASSERT_TRUE(batch->IsClosed());
+ EXPECT_GT(batch->num_tasks(), 0);
+ scheduled_items = 0;
+ for (int i = 0; i < batch->num_tasks(); i++) {
+ scheduled_items += batch->task(i).size();
+ }
+ env.SleepForMicroseconds(1);
+ };
+
+ TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
+
+ // Wait for scheduling_thread to sleep.
+ env.BlockUntilThreadsAsleep(1);
+ // Enqueue 6 batches.
+ for (int i = 0; i < 6; i++) {
+ TF_ASSERT_OK(ScheduleTask(900 + i, queue.get()));
+ env.AdvanceByMicroseconds(1);
+ }
+ feedback = -500;
+ env.AdvanceByMicroseconds(994);
+ env.BlockUntilThreadsAsleep(2); // scheduling period = 500 usec.
+ EXPECT_EQ(scheduled_items, 900);
+ env.AdvanceByMicroseconds(500);
+ env.BlockUntilThreadsAsleep(2); // scheduling period = 250 usec.
+ EXPECT_EQ(scheduled_items, 901);
+ feedback = 0;
+ env.AdvanceByMicroseconds(250);
+ env.BlockUntilThreadsAsleep(2); // scheduling period = 250 usec.
+ EXPECT_EQ(scheduled_items, 902);
+ feedback = 10000; // large feedback should hit max_scheduling_period.
+ env.AdvanceByMicroseconds(250);
+ env.BlockUntilThreadsAsleep(2); // scheduling period = 2000 usec.
+ EXPECT_EQ(scheduled_items, 903);
+ feedback = -10000; // large feedback should hit min_scheduling_period.
+ env.AdvanceByMicroseconds(1999);
+ // No callback scheduled, only scheduling thread sleeping.
+ env.BlockUntilThreadsAsleep(1);
+ EXPECT_EQ(scheduled_items, 903);
+ env.AdvanceByMicroseconds(1);
+ env.BlockUntilThreadsAsleep(2); // scheduling period = 200 usec.
+ EXPECT_EQ(scheduled_items, 904);
+ env.AdvanceByMicroseconds(200);
+ env.BlockUntilThreadsAsleep(2);
+ EXPECT_EQ(scheduled_items, 905);
+ start_teardown.Notify();
+ }
+ stop_teardown.Notify();
+}
+
+TEST(AdaptiveSharedBatchSchedulerTest, FeedbackSmoothing) {
+ test_util::FakeClockEnv env(Env::Default());
+ Notification start_teardown, stop_teardown;
+ std::unique_ptr<Thread> teardown_thread =
+ CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
+ {
+ double feedback = 0;
+ AdaptiveSharedBatchScheduler<FakeTask>::Options options;
+ options.initial_scheduling_period_micros = 1000;
+ options.env = &env;
+ options.scheduling_period_feedback = [&feedback] { return feedback; };
+ options.feedback_smoothing_batches = 3;
+ std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
+ TF_ASSERT_OK(
+ AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
+ std::unique_ptr<BatchScheduler<FakeTask>> queue;
+ int scheduled_items = 0;
+ auto queue_callback = [&scheduled_items,
+ &env](std::unique_ptr<Batch<FakeTask>> batch) {
+ ASSERT_TRUE(batch->IsClosed());
+ EXPECT_GT(batch->num_tasks(), 0);
+ scheduled_items = 0;
+ for (int i = 0; i < batch->num_tasks(); i++) {
+ scheduled_items += batch->task(i).size();
+ }
+ env.SleepForMicroseconds(1);
+ };
+
+ TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
+
+ // Wait for scheduling_thread to sleep.
+ env.BlockUntilThreadsAsleep(1);
+ // Enqueue 4 batches.
+ for (int i = 0; i < 4; i++) {
+ TF_ASSERT_OK(ScheduleTask(900 + i, queue.get()));
+ env.AdvanceByMicroseconds(1);
+ }
+ feedback = -300;
+ env.AdvanceByMicroseconds(996);
+ env.BlockUntilThreadsAsleep(2);
+    // ewma_feedback = -100, scheduling_period = 900.
+ EXPECT_EQ(scheduled_items, 900);
+ env.AdvanceByMicroseconds(899);
+ // No callback scheduled, only scheduling thread sleeping.
+ env.BlockUntilThreadsAsleep(1);
+ EXPECT_EQ(scheduled_items, 900);
+ env.AdvanceByMicroseconds(1);
+ env.BlockUntilThreadsAsleep(2);
+    // ewma_feedback = -167, scheduling_period = 750.
+ EXPECT_EQ(scheduled_items, 901);
+ env.AdvanceByMicroseconds(749);
+ // No callback scheduled, only scheduling thread sleeping.
+ env.BlockUntilThreadsAsleep(1);
+ EXPECT_EQ(scheduled_items, 901);
+ feedback = 1000 / 3.;
+ env.AdvanceByMicroseconds(1);
+ env.BlockUntilThreadsAsleep(2);
+    // ewma_feedback = 0, scheduling_period = 750.
+ EXPECT_EQ(scheduled_items, 902);
+ env.AdvanceByMicroseconds(749);
+ // No callback scheduled, only scheduling thread sleeping.
+ env.BlockUntilThreadsAsleep(1);
+ EXPECT_EQ(scheduled_items, 902);
+ env.AdvanceByMicroseconds(1);
+ env.BlockUntilThreadsAsleep(2);
+ EXPECT_EQ(scheduled_items, 903);
+ start_teardown.Notify();
+ }
+ stop_teardown.Notify();
+}
+
+TEST(AdaptiveSharedBatchSchedulerTest, QueueCapacityInfo) {
+ test_util::FakeClockEnv env(Env::Default());
+ Notification start_teardown, stop_teardown;
+ std::unique_ptr<Thread> teardown_thread =
+ CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
+ {
+ AdaptiveSharedBatchScheduler<FakeTask>::Options options;
+ options.initial_scheduling_period_micros = 1000;
+ options.env = &env;
+ std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
+ TF_ASSERT_OK(
+ AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
+ std::unique_ptr<BatchScheduler<FakeTask>> queue;
+ int scheduled_items = 0;
+ auto queue_callback = [&scheduled_items,
+ &env](std::unique_ptr<Batch<FakeTask>> batch) {
+ ASSERT_TRUE(batch->IsClosed());
+ EXPECT_GT(batch->num_tasks(), 0);
+ scheduled_items = 0;
+ for (int i = 0; i < batch->num_tasks(); i++) {
+ scheduled_items += batch->task(i).size();
+ }
+ env.SleepForMicroseconds(1);
+ };
+ AdaptiveSharedBatchScheduler<FakeTask>::QueueOptions queue_options;
+ queue_options.max_batch_size = 10;
+ queue_options.max_enqueued_batches = 10;
+ TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_callback, &queue));
+
+ // Wait for scheduling_thread to sleep.
+ env.BlockUntilThreadsAsleep(1);
+ // Enqueue 3 tasks.
+ EXPECT_EQ(queue->NumEnqueuedTasks(), 0);
+ EXPECT_EQ(queue->SchedulingCapacity(), 100);
+ TF_ASSERT_OK(ScheduleTask(5, queue.get()));
+ EXPECT_EQ(queue->NumEnqueuedTasks(), 1);
+ EXPECT_EQ(queue->SchedulingCapacity(), 95);
+ env.AdvanceByMicroseconds(1);
+ TF_ASSERT_OK(ScheduleTask(6, queue.get()));
+ EXPECT_EQ(queue->NumEnqueuedTasks(), 2);
+ EXPECT_EQ(queue->SchedulingCapacity(), 84);
+ env.AdvanceByMicroseconds(1);
+ TF_ASSERT_OK(ScheduleTask(1, queue.get()));
+ EXPECT_EQ(queue->NumEnqueuedTasks(), 3);
+ EXPECT_EQ(queue->SchedulingCapacity(), 83);
+
+ env.AdvanceByMicroseconds(998);
+ env.BlockUntilThreadsAsleep(2);
+ EXPECT_EQ(scheduled_items, 5);
+ env.AdvanceByMicroseconds(1000);
+ env.BlockUntilThreadsAsleep(2);
+ EXPECT_EQ(scheduled_items, 7);
+ start_teardown.Notify();
+ }
+ stop_teardown.Notify();
+}
+} // namespace anonymous
+} // namespace serving
+} // namespace tensorflow
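
The capacity expectations in QueueCapacityInfo above follow the formula in SchedulingCapacity(): spare batches times max_batch_size, plus the room left in the current batch. A small self-checking sketch of the test's numbers (a hypothetical helper, shown only to make the arithmetic explicit):

#include <cassert>

// Mirrors SchedulingCapacity() for QueueCapacityInfo's options:
// max_batch_size = 10, max_enqueued_batches = 10.
int Capacity(bool have_batch, int current_batch_size, int num_batches) {
  const int kMaxBatchSize = 10, kMaxEnqueuedBatches = 10;
  const int current_capacity =
      have_batch ? kMaxBatchSize - current_batch_size : 0;
  return (kMaxEnqueuedBatches - num_batches) * kMaxBatchSize + current_capacity;
}

int main() {
  assert(Capacity(false, 0, 0) == 100);  // Empty queue.
  assert(Capacity(true, 5, 1) == 95);    // Size-5 task opens the first batch.
  assert(Capacity(true, 6, 2) == 84);    // 5 + 6 > 10, so a second batch opens.
  assert(Capacity(true, 7, 2) == 83);    // Size-1 task joins the second batch.
  return 0;
}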
diff --git a/tensorflow/contrib/batching/batch_scheduler.h b/tensorflow/contrib/batching/batch_scheduler.h
index 7c41ad8818..a5072f439a 100644
--- a/tensorflow/contrib/batching/batch_scheduler.h
+++ b/tensorflow/contrib/batching/batch_scheduler.h
@@ -78,7 +78,7 @@ template <typename TaskType>
class Batch {
public:
Batch() = default;
- ~Batch(); // Blocks until the batch is closed.
+ virtual ~Batch(); // Blocks until the batch is closed.
// Appends 'task' to the batch. After calling AddTask(), the newly-added task
// can be accessed via task(num_tasks()-1) or mutable_task(num_tasks()-1).