aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/batching_util
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-30 16:27:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-30 16:29:42 -0700
commit1e007dfddd5c20f89300a2e3669f56db47e2154c (patch)
treeb32fc17a11f4fb3a485ac28dd2eb58b9410671e6 /tensorflow/core/kernels/batching_util
parentc9297e34f0ceef4afd970ee117aea9110bf8ae62 (diff)
Add SerialDeviceBatchScheduler which offers similar performance as the AdaptiveSharedBatchScheduler, but increased reliablility and stability.
ASBS assumes request latency can be minimized at a specific number of batch processing threads. Under reasonable load, this is true and ASBS performs well, but under low load latency is basically unaffected by the number of threads, and ASBS can learn a wide variety of 'optimal' values. If load resumes suddenly, these values can give very poor latencies. In most cases, ASBS will recover, eventually rediscovering the correct value, but we have observed other cases where the latency is so large and noisy that ASBS can't get a good signal to guide its learning and the number of threads remains stuck at the bad value. In addition, the incremental learning nature of this algorithm means that ASBS is always exploring to some extent, which can give rise to periods of non-optimal latency. This is most significant at high utilization where the wrong number of threads can potentially overload the system. ASBS uses latency as a proxy for keeping the tensorflow processing pipeline optimally loaded. SDBS, on the other hand, uses a direct measurement of the pipeline fullness, and adjusts its number of batch processing threads accordingly. This solves the exploration problem. SDBS solves the low load problem by not adjusting its thread count when the threads pass some idleness threshold. PiperOrigin-RevId: 198638918
Diffstat (limited to 'tensorflow/core/kernels/batching_util')
-rw-r--r--tensorflow/core/kernels/batching_util/BUILD21
-rw-r--r--tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h548
-rw-r--r--tensorflow/core/kernels/batching_util/serial_device_batch_scheduler_test.cc394
3 files changed, 963 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD
index de05c647d6..e292ff200a 100644
--- a/tensorflow/core/kernels/batching_util/BUILD
+++ b/tensorflow/core/kernels/batching_util/BUILD
@@ -127,6 +127,27 @@ tf_cc_test(
)
cc_library(
+ name = "serial_device_batch_scheduler",
+ hdrs = ["serial_device_batch_scheduler.h"],
+ deps = [
+ ":batch_scheduler",
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_cc_test(
+ name = "serial_device_batch_scheduler_test",
+ srcs = ["serial_device_batch_scheduler_test.cc"],
+ deps = [
+ ":fake_clock_env",
+ ":serial_device_batch_scheduler",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
name = "basic_batch_scheduler",
hdrs = ["basic_batch_scheduler.h"],
deps = [
diff --git a/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h b/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h
new file mode 100644
index 0000000000..518f2ff8a9
--- /dev/null
+++ b/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h
@@ -0,0 +1,548 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SERIAL_DEVICE_BATCH_SCHEDULER_H_
+#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SERIAL_DEVICE_BATCH_SCHEDULER_H_
+
+#include <algorithm>
+#include <functional>
+#include <memory>
+#include <random>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace serving {
+namespace internal {
+template <typename TaskType>
+class SDBSBatch;
+
+template <typename TaskType>
+class SDBSQueue;
+} // namespace internal
+
+// EXPERIMENTAL: API MAY BE SUBJECTED TO SUDDEN CHANGES.
+//
+// Shared batch scheduler designed for batches which are processed by a serial
+// device (e.g. GPU, TPU). When batch processing involves a mix of
+// parallelizable cpu work and non-parallelizable on-device work, overall
+// latency can be minimized by producing batches at a (load dependent) rate
+// which keeps the serial device uniformly busy.
+//
+// SerialDeviceBatchScheduler (SDBS) controls the batching rate by limiting the
+// allowed number of concurrently processed batches. Too large a limit causes
+// batches to pile up behind the serial device, adding to the overall batch
+// latency. Too small a limit underutilizes the serial device and harms latency
+// by forcing batches to wait longer to be processed. Feedback from the device
+// (i.e. avg number of batches directly pending on the device) is used to set
+// the correct limit.
+//
+// SDBS groups requests into per model batches which are processed when a batch
+// processing thread becomes available. SDBS prioritizes batches primarily by
+// age (i.e. the batch's oldest request) along with a configurable preference
+// for scheduling larger batches first.
+
+
+template <typename TaskType>
+class SerialDeviceBatchScheduler : public std::enable_shared_from_this<
+ SerialDeviceBatchScheduler<TaskType>> {
+ public:
+ ~SerialDeviceBatchScheduler();
+
+ struct Options {
+ // The name to use for the pool of batch threads.
+ string thread_pool_name = {"batch_threads"};
+ // Maximum number of batch processing threads.
+ int64 num_batch_threads = port::NumSchedulableCPUs();
+ // Although batch selection is primarily based on age, this parameter
+ // specifies a preference for larger batches. A full batch will be
+ // scheduled before an older, nearly empty batch as long as the age gap is
+ // less than full_batch_scheduling_boost_micros. The optimal value for this
+ // parameter should be of order the batch processing latency, but must be
+ // chosen carefully, as too large a value will harm tail latency.
+ int64 full_batch_scheduling_boost_micros = 0;
+ // The environment to use (typically only overridden by test code).
+ Env* env = Env::Default();
+ // Initial limit for number of batches being concurrently processed.
+ int64 initial_in_flight_batches_limit = 3;
+ // Returns the current number of batches directly waiting to be processed
+ // by the serial device (i.e. GPU, TPU).
+ std::function<int64()> get_pending_on_serial_device;
+ // Desired average number of batches directly waiting to be processed by the
+ // serial device. Small numbers of O(1) should deliver the best latency.
+ double target_pending = 2;
+ // Number of batches between potential adjustments of
+ // in_flight_batches_limit. Larger numbers will reduce noise, but will be
+ // less responsive to sudden changes in workload.
+ int64 batches_to_average_over = 1000;
+ };
+
+ // Ownership is shared between the caller of Create() and any queues created
+ // via AddQueue().
+ static Status Create(
+ const Options& options,
+ std::shared_ptr<SerialDeviceBatchScheduler<TaskType>>* scheduler);
+
+ struct QueueOptions {
+ // Maximum size of each batch.
+ int max_batch_size = 1000;
+ // Maximum number of enqueued (i.e. non-scheduled) batches.
+ int max_enqueued_batches = 10;
+ };
+
+ using BatchProcessor = std::function<void(std::unique_ptr<Batch<TaskType>>)>;
+
+ // Adds queue (and its callback) to be managed by this scheduler.
+ Status AddQueue(const QueueOptions& options,
+ BatchProcessor process_batch_callback,
+ std::unique_ptr<BatchScheduler<TaskType>>* queue);
+
+ double in_flight_batches_limit() {
+ mutex_lock l(mu_);
+ return in_flight_batches_limit_;
+ }
+
+ double recent_low_traffic_ratio() {
+ mutex_lock l(mu_);
+ return recent_low_traffic_ratio_;
+ }
+
+ private:
+ // access to AddBatch(), RemoveQueue(), env().
+ friend class internal::SDBSQueue<TaskType>;
+
+ explicit SerialDeviceBatchScheduler(const Options& options);
+
+ // Continuously retrieves and processes batches.
+ void ProcessBatches();
+
+ // Notifies scheduler of non-empty batch which is eligible for processing.
+ void AddBatch(const internal::SDBSBatch<TaskType>* batch);
+
+ // Removes queue from scheduler.
+ void RemoveQueue(const internal::SDBSQueue<TaskType>* queue);
+
+ Env* env() const { return options_.env; }
+
+ const Options options_;
+
+ // Collection of batches added by AddBatch. Owned by scheduler until they are
+ // released for processing.
+ std::vector<const internal::SDBSBatch<TaskType>*> batches_ GUARDED_BY(mu_);
+
+ // Unowned queues and callbacks added by AddQueue.
+ std::unordered_map<const internal::SDBSQueue<TaskType>*, BatchProcessor>
+ queues_and_callbacks_ GUARDED_BY(mu_);
+
+ // Responsible for running the batch processing callbacks.
+ std::unique_ptr<thread::ThreadPool> batch_thread_pool_;
+
+ // Limit on number of batches which can be concurrently processed.
+ int64 in_flight_batches_limit_ GUARDED_BY(mu_);
+
+ // Number of batch processing threads.
+ int64 processing_threads_ GUARDED_BY(mu_) = 0;
+
+ // Number of batches processed since the last in_flight_batches_limit_
+ // adjustment.
+ int64 batch_count_ GUARDED_BY(mu_) = 0;
+
+ // Number of times since the last in_flight_batches_limit_ adjustment when a
+ // processing thread was available but there were no batches to process.
+ int64 no_batch_count_ GUARDED_BY(mu_) = 0;
+
+ // Sum of batches pending on the serial device since the last
+ // in_flight_batches_limit_ adjustment.
+ int64 pending_sum_ = 0;
+
+ // Sum of batch latencies since the last in_flight_batches_limit_ adjustment.
+ int64 batch_latency_sum_ = 0;
+
+ // Average period between which two consecutive batches begin processing.
+ int64 batch_period_micros_ = 0;
+
+ // Moving average tracking the fraction of recent in_flight_batches_limit_
+ // adjustments where the external traffic was not high enough to provide
+ // useful feedback for an adjustment.
+ double recent_low_traffic_ratio_ = 0;
+
+ mutex mu_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SerialDeviceBatchScheduler);
+};
+
+//////////////////////////////////////////////////////////
+// Implementation details follow. API users need not read.
+
+namespace internal {
+// Consolidates tasks into batches, passing them off to the
+// SerialDeviceBatchScheduler for processing.
+template <typename TaskType>
+class SDBSQueue : public BatchScheduler<TaskType> {
+ public:
+ using QueueOptions =
+ typename SerialDeviceBatchScheduler<TaskType>::QueueOptions;
+
+ SDBSQueue(std::shared_ptr<SerialDeviceBatchScheduler<TaskType>> scheduler,
+ const QueueOptions& options);
+
+ ~SDBSQueue() override;
+
+ // Adds task to current batch. Fails if the task size is larger than the batch
+ // size or if the current batch is full and this queue's number of outstanding
+ // batches is at its maximum.
+ Status Schedule(std::unique_ptr<TaskType>* task) override;
+
+ // Number of tasks waiting to be scheduled.
+ size_t NumEnqueuedTasks() const override;
+
+ // Number of size 1 tasks which could currently be scheduled without failing.
+ size_t SchedulingCapacity() const override;
+
+ // Notifies queue that a batch is about to be scheduled; the queue should not
+ // place any more tasks in this batch.
+ void ReleaseBatch(const SDBSBatch<TaskType>* batch);
+
+ size_t max_task_size() const override { return options_.max_batch_size; }
+
+ private:
+ std::shared_ptr<SerialDeviceBatchScheduler<TaskType>> scheduler_;
+ const QueueOptions options_;
+ // Owned by scheduler_.
+ SDBSBatch<TaskType>* current_batch_ GUARDED_BY(mu_) = nullptr;
+ int64 num_enqueued_batches_ GUARDED_BY(mu_) = 0;
+ int64 num_enqueued_tasks_ GUARDED_BY(mu_) = 0;
+ mutable mutex mu_;
+ TF_DISALLOW_COPY_AND_ASSIGN(SDBSQueue);
+};
+
+// Batch which remembers when and by whom it was created.
+template <typename TaskType>
+class SDBSBatch : public Batch<TaskType> {
+ public:
+ SDBSBatch(SDBSQueue<TaskType>* queue, int64 creation_time_micros)
+ : queue_(queue), creation_time_micros_(creation_time_micros) {}
+
+ ~SDBSBatch() override {}
+
+ SDBSQueue<TaskType>* queue() const { return queue_; }
+
+ int64 creation_time_micros() const { return creation_time_micros_; }
+
+ private:
+ SDBSQueue<TaskType>* queue_;
+ const int64 creation_time_micros_;
+ TF_DISALLOW_COPY_AND_ASSIGN(SDBSBatch);
+};
+} // namespace internal
+
+// ---------------- SerialDeviceBatchScheduler ----------------
+
+template <typename TaskType>
+Status SerialDeviceBatchScheduler<TaskType>::Create(
+ const Options& options,
+ std::shared_ptr<SerialDeviceBatchScheduler<TaskType>>* scheduler) {
+ if (options.num_batch_threads < 1) {
+ return errors::InvalidArgument("num_batch_threads must be positive; was ",
+ options.num_batch_threads);
+ }
+ if (options.initial_in_flight_batches_limit < 1) {
+ return errors::InvalidArgument(
+ "initial_in_flight_batches_limit must be positive; was ",
+ options.initial_in_flight_batches_limit);
+ }
+ if (options.initial_in_flight_batches_limit > options.num_batch_threads) {
+ return errors::InvalidArgument(
+ "initial_in_flight_batches_limit (",
+ options.initial_in_flight_batches_limit,
+ ") should not be larger than num_batch_threads (",
+ options.num_batch_threads, ")");
+ }
+ if (options.full_batch_scheduling_boost_micros < 0) {
+ return errors::InvalidArgument(
+ "full_batch_scheduling_boost_micros can't be negative; was ",
+ options.full_batch_scheduling_boost_micros);
+ }
+ if (options.batches_to_average_over < 1) {
+ return errors::InvalidArgument(
+ "batches_to_average_over should be "
+ "greater than or equal to 1; was ",
+ options.batches_to_average_over);
+ }
+ if (options.target_pending <= 0) {
+ return errors::InvalidArgument(
+ "target_pending should be larger than zero; was ",
+ options.target_pending);
+ }
+ if (!options.get_pending_on_serial_device) {
+ return errors::InvalidArgument(
+ "get_pending_on_serial_device must be "
+ "specified");
+ }
+ scheduler->reset(new SerialDeviceBatchScheduler<TaskType>(options));
+ return Status::OK();
+}
+
+template <typename TaskType>
+SerialDeviceBatchScheduler<TaskType>::SerialDeviceBatchScheduler(
+ const Options& options)
+ : options_(options),
+ in_flight_batches_limit_(options.initial_in_flight_batches_limit),
+ processing_threads_(options.initial_in_flight_batches_limit) {
+ batch_thread_pool_.reset(new thread::ThreadPool(
+ env(), options.thread_pool_name, options.num_batch_threads));
+ for (int i = 0; i < processing_threads_; i++) {
+ batch_thread_pool_->Schedule(
+ std::bind(&SerialDeviceBatchScheduler<TaskType>::ProcessBatches, this));
+ }
+}
+
+template <typename TaskType>
+SerialDeviceBatchScheduler<TaskType>::~SerialDeviceBatchScheduler() {
+ // Signal processing threads to exit.
+ {
+ mutex_lock l(mu_);
+ processing_threads_ = 0;
+ }
+ // Hangs until all threads finish.
+ batch_thread_pool_.reset();
+}
+
+template <typename TaskType>
+Status SerialDeviceBatchScheduler<TaskType>::AddQueue(
+ const QueueOptions& options, BatchProcessor process_batch_callback,
+ std::unique_ptr<BatchScheduler<TaskType>>* queue) {
+ if (options.max_batch_size <= 0) {
+ return errors::InvalidArgument("max_batch_size must be positive; was ",
+ options.max_batch_size);
+ }
+ if (options.max_enqueued_batches <= 0) {
+ return errors::InvalidArgument(
+ "max_enqueued_batches must be positive; was ",
+ options.max_enqueued_batches);
+ }
+ internal::SDBSQueue<TaskType>* SDBS_queue_raw;
+ queue->reset(SDBS_queue_raw = new internal::SDBSQueue<TaskType>(
+ this->shared_from_this(), options));
+ mutex_lock l(mu_);
+ queues_and_callbacks_[SDBS_queue_raw] = process_batch_callback;
+ return Status::OK();
+}
+
+template <typename TaskType>
+void SerialDeviceBatchScheduler<TaskType>::AddBatch(
+ const internal::SDBSBatch<TaskType>* batch) {
+ mutex_lock l(mu_);
+ batches_.push_back(batch);
+}
+
+template <typename TaskType>
+void SerialDeviceBatchScheduler<TaskType>::RemoveQueue(
+ const internal::SDBSQueue<TaskType>* queue) {
+ mutex_lock l(mu_);
+ queues_and_callbacks_.erase(queue);
+}
+
+template <typename TaskType>
+void SerialDeviceBatchScheduler<TaskType>::ProcessBatches() {
+ const int64 kIdleThreadSleepTimeMicros = 1000;
+ const double kMaxNoBatchRatio = .1;
+ const double kLowTrafficMovingAverageFactor = .1;
+ for (;;) {
+ mu_.lock();
+ if (processing_threads_ < 1 ||
+ processing_threads_ > in_flight_batches_limit_) {
+ processing_threads_--;
+ mu_.unlock();
+ break;
+ }
+ if (batches_.empty()) {
+ no_batch_count_++;
+ int64 sleep_time = batch_period_micros_ ? batch_period_micros_
+ : kIdleThreadSleepTimeMicros;
+ mu_.unlock();
+ env()->SleepForMicroseconds(sleep_time);
+ continue;
+ }
+ auto best_it = batches_.begin();
+ double best_score =
+ (*best_it)->creation_time_micros() -
+ options_.full_batch_scheduling_boost_micros * (*best_it)->size() /
+ static_cast<double>((*best_it)->queue()->max_task_size());
+ for (auto it = batches_.begin() + 1; it != batches_.end(); it++) {
+ const double score =
+ (*it)->creation_time_micros() -
+ options_.full_batch_scheduling_boost_micros * (*it)->size() /
+ static_cast<double>((*it)->queue()->max_task_size());
+ if (score < best_score) {
+ best_score = score;
+ best_it = it;
+ }
+ }
+ const internal::SDBSBatch<TaskType>* batch = *best_it;
+ batches_.erase(best_it);
+ // Queue may destroy itself after ReleaseBatch is called.
+ batch->queue()->ReleaseBatch(batch);
+ auto callback = queues_and_callbacks_[batch->queue()];
+ mu_.unlock();
+ int64 start_time = env()->NowMicros();
+ callback(std::unique_ptr<Batch<TaskType>>(
+ const_cast<internal::SDBSBatch<TaskType>*>(batch)));
+ int64 end_time = env()->NowMicros();
+ mu_.lock();
+ batch_count_++;
+ batch_latency_sum_ += end_time - start_time;
+ pending_sum_ += options_.get_pending_on_serial_device();
+ if (batch_count_ == options_.batches_to_average_over) {
+ recent_low_traffic_ratio_ *= (1 - kLowTrafficMovingAverageFactor);
+ // Only adjust in_flight_batches_limit_ if external load is large enough
+ // to consistently provide batches. Otherwise we would (mistakenly) assume
+ // that the device is underutilized because in_flight_batches_limit_ is
+ // too small.
+ if (no_batch_count_ < kMaxNoBatchRatio * batch_count_) {
+ double avg_pending = pending_sum_ / static_cast<double>(batch_count_);
+ // Avg processing time / # of concurrent batches gives the avg period
+ // between which two consecutive batches begin processing. Used to set a
+ // reasonable sleep time for idle batch processing threads.
+ batch_period_micros_ =
+ batch_latency_sum_ / batch_count_ / in_flight_batches_limit_;
+ // When the processing pipeline is consistently busy, the average number
+ // of pending batches differs from in_flight_batches_limit_ by a
+ // load-dependent offset. Adjust in_flight_batches_limit_to maintain
+ // the desired target pending.
+ in_flight_batches_limit_ +=
+ std::round(options_.target_pending - avg_pending);
+ in_flight_batches_limit_ = std::max(in_flight_batches_limit_, 1LL);
+ in_flight_batches_limit_ =
+ std::min(in_flight_batches_limit_, options_.num_batch_threads);
+ // Add extra processing threads if necessary.
+ if (processing_threads_ > 0 &&
+ processing_threads_ < in_flight_batches_limit_) {
+ int extra_threads = in_flight_batches_limit_ - processing_threads_;
+ for (int i = 0; i < extra_threads; i++) {
+ batch_thread_pool_->Schedule(std::bind(
+ &SerialDeviceBatchScheduler<TaskType>::ProcessBatches, this));
+ }
+ processing_threads_ = in_flight_batches_limit_;
+ }
+ } else {
+ recent_low_traffic_ratio_ += kLowTrafficMovingAverageFactor;
+ }
+ batch_count_ = 0;
+ no_batch_count_ = 0;
+ pending_sum_ = 0;
+ batch_latency_sum_ = 0;
+ }
+ mu_.unlock();
+ }
+}
+
+// ---------------- SDBSQueue ----------------
+
+namespace internal {
+template <typename TaskType>
+SDBSQueue<TaskType>::SDBSQueue(
+ std::shared_ptr<SerialDeviceBatchScheduler<TaskType>> scheduler,
+ const QueueOptions& options)
+ : scheduler_(scheduler), options_(options) {}
+
+template <typename TaskType>
+SDBSQueue<TaskType>::~SDBSQueue() {
+ // Wait until last batch has been scheduled.
+ const int kSleepMicros = 1000;
+ for (;;) {
+ {
+ mutex_lock l(mu_);
+ if (num_enqueued_batches_ == 0) {
+ break;
+ }
+ }
+ scheduler_->env()->SleepForMicroseconds(kSleepMicros);
+ }
+ scheduler_->RemoveQueue(this);
+}
+
+template <typename TaskType>
+Status SDBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
+ SDBSBatch<TaskType>* new_batch = nullptr;
+ size_t size = (*task)->size();
+ if (size > options_.max_batch_size) {
+ return errors::InvalidArgument("Task size ", size,
+ " is larger than maximum batch size ",
+ options_.max_batch_size);
+ }
+ {
+ mutex_lock l(mu_);
+ // Current batch is full, create another if allowed.
+ if (current_batch_ &&
+ current_batch_->size() + size > options_.max_batch_size) {
+ if (num_enqueued_batches_ >= options_.max_enqueued_batches) {
+ return errors::Unavailable("The batch scheduling queue is full");
+ }
+ current_batch_->Close();
+ current_batch_ = nullptr;
+ }
+ if (!current_batch_) {
+ num_enqueued_batches_++;
+ current_batch_ = new_batch =
+ new SDBSBatch<TaskType>(this, scheduler_->env()->NowMicros());
+ }
+ current_batch_->AddTask(std::move(*task));
+ num_enqueued_tasks_++;
+ }
+ // AddBatch must be called outside of lock, since it may call ReleaseBatch.
+ if (new_batch != nullptr) scheduler_->AddBatch(new_batch);
+ return Status::OK();
+}
+
+template <typename TaskType>
+void SDBSQueue<TaskType>::ReleaseBatch(const SDBSBatch<TaskType>* batch) {
+ mutex_lock l(mu_);
+ num_enqueued_batches_--;
+ num_enqueued_tasks_ -= batch->num_tasks();
+ if (batch == current_batch_) {
+ current_batch_->Close();
+ current_batch_ = nullptr;
+ }
+}
+
+template <typename TaskType>
+size_t SDBSQueue<TaskType>::NumEnqueuedTasks() const {
+ mutex_lock l(mu_);
+ return num_enqueued_tasks_;
+}
+
+template <typename TaskType>
+size_t SDBSQueue<TaskType>::SchedulingCapacity() const {
+ mutex_lock l(mu_);
+ const int current_batch_capacity =
+ current_batch_ ? options_.max_batch_size - current_batch_->size() : 0;
+ const int spare_batches =
+ options_.max_enqueued_batches - num_enqueued_batches_;
+ return spare_batches * options_.max_batch_size + current_batch_capacity;
+}
+} // namespace internal
+} // namespace serving
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SERIAL_DEVICE_BATCH_SCHEDULER_H_
diff --git a/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler_test.cc
new file mode 100644
index 0000000000..a2f8f9a03e
--- /dev/null
+++ b/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler_test.cc
@@ -0,0 +1,394 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h"
+
+#include "tensorflow/core/kernels/batching_util/fake_clock_env.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace serving {
+namespace anonymous {
+
+class FakeTask : public BatchTask {
+ public:
+ explicit FakeTask(size_t size) : size_(size) {}
+
+ ~FakeTask() override = default;
+
+ size_t size() const override { return size_; }
+
+ private:
+ const size_t size_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FakeTask);
+};
+
+// Creates a FakeTask of size 'task_size', and calls 'scheduler->Schedule()' on
+// that task. Returns the resulting status.
+Status ScheduleTask(size_t task_size, BatchScheduler<FakeTask>* scheduler) {
+ std::unique_ptr<FakeTask> task(new FakeTask(task_size));
+ Status status = scheduler->Schedule(&task);
+ // Schedule() should have consumed 'task' iff it returned Status::OK.
+ CHECK_EQ(status.ok(), task == nullptr);
+ return status;
+}
+
+// Creates a thread that waits on 'start' and then advances the fake clock in
+// 'env' in a loop until 'stop' is notified. Useful for allowing objects that
+// use the clock to be destroyed.
+std::unique_ptr<Thread> CreateFakeClockAdvancerThread(
+ test_util::FakeClockEnv* env, Notification* start, Notification* stop) {
+ return std::unique_ptr<Thread>(Env::Default()->StartThread(
+ {}, "FakeClockAdvancerThread", [env, start, stop] {
+ start->WaitForNotification();
+ while (!stop->HasBeenNotified()) {
+ env->AdvanceByMicroseconds(10);
+ Env::Default()->SleepForMicroseconds(10);
+ }
+ }));
+}
+
+TEST(SerialDeviceBatchSchedulerTest, BadOptions) {
+ using Scheduler = SerialDeviceBatchScheduler<FakeTask>;
+ std::shared_ptr<Scheduler> scheduler;
+ Scheduler::Options default_options;
+ default_options.get_pending_on_serial_device = []() { return 0; };
+ Scheduler::Options options = default_options;
+ options.num_batch_threads = 0;
+ EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
+ options = default_options;
+ options.initial_in_flight_batches_limit = 0;
+ EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
+ options = default_options;
+ options.num_batch_threads = 5;
+ options.initial_in_flight_batches_limit = 8;
+ EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
+ options = default_options;
+ options.batches_to_average_over = -5;
+ EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
+ options = default_options;
+ options.target_pending = 0;
+ EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
+ options = Scheduler::Options();
+ EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
+}
+
+TEST(SerialDeviceBatchSchedulerTest, InFlightBatchesLimit) {
+ SerialDeviceBatchScheduler<FakeTask>::Options options;
+ options.num_batch_threads = 3;
+ options.initial_in_flight_batches_limit = 2;
+ options.batches_to_average_over = 1000;
+ options.get_pending_on_serial_device = []() { return 0; };
+ mutex mu;
+ int processed_batches = 0;
+ Notification finish_processing;
+ auto queue_callback = [&mu, &processed_batches, &finish_processing](
+ std::unique_ptr<Batch<FakeTask>> batch) {
+ ASSERT_TRUE(batch->IsClosed());
+ EXPECT_GT(batch->num_tasks(), 0);
+ mu.lock();
+ int batch_num = ++processed_batches;
+ mu.unlock();
+ if (batch_num == 2) {
+ // Give third batch a chance to process if it's going to.
+ Env::Default()->SleepForMicroseconds(1000);
+ finish_processing.Notify();
+ }
+ if (batch_num == 3) {
+ ASSERT_TRUE(finish_processing.HasBeenNotified());
+ }
+ finish_processing.WaitForNotification();
+ };
+ std::shared_ptr<SerialDeviceBatchScheduler<FakeTask>> scheduler;
+ TF_ASSERT_OK(
+ SerialDeviceBatchScheduler<FakeTask>::Create(options, &scheduler));
+ std::unique_ptr<BatchScheduler<FakeTask>> queue1;
+ std::unique_ptr<BatchScheduler<FakeTask>> queue2;
+ std::unique_ptr<BatchScheduler<FakeTask>> queue3;
+ TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue1));
+ TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue2));
+ TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue3));
+ // Create 3 batches, only 2 should be processed concurrently.
+ TF_ASSERT_OK(ScheduleTask(100, queue1.get()));
+ TF_ASSERT_OK(ScheduleTask(100, queue2.get()));
+ TF_ASSERT_OK(ScheduleTask(100, queue3.get()));
+}
+
+TEST(SerialDeviceBatchSchedulerTest, PendingOnSerialDevice) {
+ mutex mu;
+ int pending;
+ SerialDeviceBatchScheduler<FakeTask>::Options options;
+ options.num_batch_threads = 3;
+ options.initial_in_flight_batches_limit = 1;
+ options.batches_to_average_over = 1;
+ options.target_pending = 3;
+ options.get_pending_on_serial_device = [&mu, &pending]() {
+ mutex_lock l(mu);
+ return pending;
+ };
+ std::shared_ptr<SerialDeviceBatchScheduler<FakeTask>> scheduler;
+ TF_ASSERT_OK(
+ SerialDeviceBatchScheduler<FakeTask>::Create(options, &scheduler));
+ // Make sure batch processing thread has gone to sleep.
+ Env::Default()->SleepForMicroseconds(1000);
+ int processed_batches = 0;
+ Notification start_processing;
+ auto queue_callback = [&mu, &processed_batches, &start_processing, &pending,
+ &scheduler](std::unique_ptr<Batch<FakeTask>> batch) {
+ // Be careful with mutex mu to avoid potential deadlock with mutex mu_
+ // held in ProcessBatch() and in_flight_batches_limit().
+ int batch_num;
+ {
+ mutex_lock l(mu);
+ batch_num = ++processed_batches;
+ }
+ switch (batch_num) {
+ case 1:
+ start_processing.WaitForNotification();
+ {
+ mutex_lock l(mu);
+ pending = 2;
+ }
+ break;
+ case 2:
+ // No batches initially --> low traffic --> no adjustment.
+ CHECK_EQ(scheduler->in_flight_batches_limit(), 1);
+ {
+ mutex_lock l(mu);
+ pending = 3;
+ }
+ break;
+ case 3:
+ // Pending at target --> no adjustment.
+ CHECK_EQ(scheduler->in_flight_batches_limit(), 1);
+ {
+ mutex_lock l(mu);
+ pending = 1;
+ }
+ break;
+ case 4:
+ // Small pending --> 2 additional threads added.
+ CHECK_EQ(scheduler->in_flight_batches_limit(), 3);
+ {
+ mutex_lock l(mu);
+ pending = 3;
+ }
+ break;
+ default:
+ break;
+ }
+ };
+ std::unique_ptr<BatchScheduler<FakeTask>> queue;
+ TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
+ // Create 4 batches.
+ for (int i = 0; i < 4; i++) {
+ TF_ASSERT_OK(ScheduleTask(800, queue.get()));
+ }
+ start_processing.Notify();
+}
+
+TEST(SerialDeviceBatchSchedulerTest, FullBatchSchedulingBoostMicros) {
+ test_util::FakeClockEnv env(Env::Default());
+ Notification start_teardown, stop_teardown;
+ std::unique_ptr<Thread> teardown_thread =
+ CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
+ {
+ SerialDeviceBatchScheduler<FakeTask>::Options options;
+ options.env = &env;
+ options.initial_in_flight_batches_limit = 1;
+ options.batches_to_average_over = 1000;
+ options.full_batch_scheduling_boost_micros = 10;
+ options.get_pending_on_serial_device = []() { return 0; };
+ mutex mu;
+ int processed_batches = 0;
+ auto queue_callback =
+ [&mu, &processed_batches](std::unique_ptr<Batch<FakeTask>> batch) {
+ ASSERT_TRUE(batch->IsClosed());
+ mutex_lock l(mu);
+ processed_batches++;
+ switch (processed_batches) {
+ case 1:
+ EXPECT_EQ(1000, batch->size());
+ break;
+ case 2:
+ EXPECT_EQ(100, batch->size());
+ break;
+ case 3:
+ EXPECT_EQ(80, batch->size());
+ break;
+ default:
+ EXPECT_TRUE(false) << "Should only have 3 batches";
+ }
+ };
+ std::shared_ptr<SerialDeviceBatchScheduler<FakeTask>> scheduler;
+ TF_ASSERT_OK(
+ SerialDeviceBatchScheduler<FakeTask>::Create(options, &scheduler));
+ // Make sure batch processing thread has gone to sleep.
+ Env::Default()->SleepForMicroseconds(1000);
+ SerialDeviceBatchScheduler<FakeTask>::QueueOptions queue_options;
+ std::unique_ptr<BatchScheduler<FakeTask>> queue1;
+ std::unique_ptr<BatchScheduler<FakeTask>> queue2;
+ std::unique_ptr<BatchScheduler<FakeTask>> queue3;
+ queue_options.max_batch_size = 1000;
+ TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_callback, &queue1));
+ queue_options.max_batch_size = 1000;
+ TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_callback, &queue2));
+ queue_options.max_batch_size = 100;
+ TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_callback, &queue3));
+
+ TF_ASSERT_OK(ScheduleTask(100, queue1.get()));
+ // First batch - creation time: 0, fullness: 0.1, sched score: -1
+ env.AdvanceByMicroseconds(3);
+ TF_ASSERT_OK(ScheduleTask(1000, queue2.get()));
+ // Second batch - creation time: 3, fullness: 1, sched score: -7
+ env.AdvanceByMicroseconds(5);
+ TF_ASSERT_OK(ScheduleTask(80, queue3.get()));
+ // Third batch - creation time: 8, fullness: .8, sched score: 0
+ // Release the batch processing thread.
+ env.AdvanceByMicroseconds(1000);
+ start_teardown.Notify();
+ }
+ stop_teardown.Notify();
+}
+
+TEST(SerialDeviceBatchSchedulerTest, DeleteQueue) {
+ SerialDeviceBatchScheduler<FakeTask>::Options options;
+ options.initial_in_flight_batches_limit = 1;
+ options.batches_to_average_over = 1000;
+ options.get_pending_on_serial_device = []() { return 0; };
+ mutex mu;
+ int processed_batches = 0;
+ Notification finish_processing;
+ auto queue_callback = [&mu, &processed_batches, &finish_processing](
+ std::unique_ptr<Batch<FakeTask>> batch) {
+ ASSERT_TRUE(batch->IsClosed());
+ EXPECT_GT(batch->num_tasks(), 0);
+ finish_processing.WaitForNotification();
+ mu.lock();
+ processed_batches++;
+ mu.unlock();
+ };
+ std::shared_ptr<SerialDeviceBatchScheduler<FakeTask>> scheduler;
+ TF_ASSERT_OK(
+ SerialDeviceBatchScheduler<FakeTask>::Create(options, &scheduler));
+ std::unique_ptr<BatchScheduler<FakeTask>> queue;
+ TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
+
+ // Enqueue 2 tasks, should result in 2 batches.
+ for (int i = 0; i < 2; i++) {
+ TF_ASSERT_OK(ScheduleTask(800, queue.get()));
+ }
+ std::unique_ptr<Thread> queue_deleter(Env::Default()->StartThread(
+ {}, "QueueDeleterThread", [&queue, &mu, &processed_batches] {
+ // Delete queue, should be kept alive until empty.
+ queue.reset();
+ mutex_lock l(mu);
+ EXPECT_EQ(processed_batches, 2);
+ }));
+ // Give queue_deleter thread time to delete queue.
+ Env::Default()->SleepForMicroseconds(1000);
+ finish_processing.Notify();
+}
+
+TEST(SerialDeviceBatchSchedulerTest, DeleteScheduler) {
+ SerialDeviceBatchScheduler<FakeTask>::Options options;
+ options.initial_in_flight_batches_limit = 1;
+ options.batches_to_average_over = 1000;
+ options.get_pending_on_serial_device = []() { return 0; };
+ mutex mu;
+ int processed_batches = 0;
+ Notification start_processing;
+ Notification finish_processing;
+ auto queue_callback =
+ [&mu, &processed_batches, &start_processing,
+ &finish_processing](std::unique_ptr<Batch<FakeTask>> batch) {
+ ASSERT_TRUE(batch->IsClosed());
+ EXPECT_GT(batch->num_tasks(), 0);
+ start_processing.WaitForNotification();
+ mutex_lock l(mu);
+ processed_batches++;
+ if (processed_batches == 2) {
+ finish_processing.Notify();
+ }
+ };
+
+ std::shared_ptr<SerialDeviceBatchScheduler<FakeTask>> scheduler;
+ TF_ASSERT_OK(
+ SerialDeviceBatchScheduler<FakeTask>::Create(options, &scheduler));
+ std::unique_ptr<BatchScheduler<FakeTask>> queue;
+ TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
+
+ // Enqueue 2 tasks, should result in 2 batches.
+ for (int i = 0; i < 2; i++) {
+ TF_ASSERT_OK(ScheduleTask(800, queue.get()));
+ }
+ // Delete scheduler, should be kept alive until queues are empty.
+ scheduler.reset();
+ start_processing.Notify();
+ finish_processing.WaitForNotification();
+}
+
+TEST(SerialDeviceBatchSchedulerTest, QueueCapacityInfo) {
+ SerialDeviceBatchScheduler<FakeTask>::Options options;
+ options.initial_in_flight_batches_limit = 1;
+ options.batches_to_average_over = 1000;
+ options.full_batch_scheduling_boost_micros = 1000;
+ options.get_pending_on_serial_device = []() { return 0; };
+ mutex mu;
+ int processed_batches = 0;
+ Notification finish_processing;
+ auto queue_callback = [&mu, &processed_batches, &finish_processing](
+ std::unique_ptr<Batch<FakeTask>> batch) {
+ ASSERT_TRUE(batch->IsClosed());
+ EXPECT_GT(batch->num_tasks(), 0);
+ mu.lock();
+ int batch_num = ++processed_batches;
+ mu.unlock();
+ if (batch_num == 1) {
+ finish_processing.WaitForNotification();
+ }
+ };
+ std::shared_ptr<SerialDeviceBatchScheduler<FakeTask>> scheduler;
+ TF_ASSERT_OK(
+ SerialDeviceBatchScheduler<FakeTask>::Create(options, &scheduler));
+ std::unique_ptr<BatchScheduler<FakeTask>> queue1;
+ std::unique_ptr<BatchScheduler<FakeTask>> queue2;
+ TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue1));
+ TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue2));
+
+ // Blocker task, should schedule first.
+ TF_ASSERT_OK(ScheduleTask(800, queue1.get()));
+ TF_ASSERT_OK(ScheduleTask(100, queue2.get()));
+
+ EXPECT_EQ(queue2->NumEnqueuedTasks(), 1);
+ EXPECT_EQ(queue2->SchedulingCapacity(), 9 * 1000 + 900);
+ // Enqueue 2 more tasks, should fall in same batch.
+ TF_ASSERT_OK(ScheduleTask(100, queue2.get()));
+ TF_ASSERT_OK(ScheduleTask(200, queue2.get()));
+ EXPECT_EQ(queue2->NumEnqueuedTasks(), 3);
+ EXPECT_EQ(queue2->SchedulingCapacity(), 9 * 1000 + 600);
+ // Enqueue 1 more task, should create new batch.
+ TF_ASSERT_OK(ScheduleTask(700, queue2.get()));
+ EXPECT_EQ(queue2->NumEnqueuedTasks(), 4);
+ EXPECT_EQ(queue2->SchedulingCapacity(), 8 * 1000 + 300);
+ finish_processing.Notify();
+}
+} // namespace anonymous
+} // namespace serving
+} // namespace tensorflow