path: root/tensorflow/contrib/batching
author     Alexandre Passos <apassos@google.com>            2018-01-12 14:46:20 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-01-12 14:53:40 -0800
commit     e70a88a224f2eacf3fb0aa7348c499efe1246eac (patch)
tree       258f0a791d90a0e992b9d3f61d3618c917be2f0b /tensorflow/contrib/batching
parent     8297cb6f5af797398578f05d1660608ad9b6161c (diff)
Reorganize the shared batching scheduler headers, leaving forwarding shims.
PiperOrigin-RevId: 181795909
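
Each header that stays behind under tensorflow/contrib/batching is reduced to a forwarding shim: it keeps its old include path and guard macro, but its body is now a single include of the relocated header under tensorflow/core/kernels/batching_util, so existing includes keep compiling. A minimal sketch of the resulting shim (license banner omitted), mirroring the adaptive_shared_batch_scheduler.h hunk shown below:

// Forwarding shim left at the old contrib location: same path and guard
// macro as before; the guard name and target include are taken from the
// hunk below.
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_

#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"

#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_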
Diffstat (limited to 'tensorflow/contrib/batching')
-rw-r--r--  tensorflow/contrib/batching/BUILD                                     85
-rw-r--r--  tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h        640
-rw-r--r--  tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc  549
-rw-r--r--  tensorflow/contrib/batching/basic_batch_scheduler.h                  249
-rw-r--r--  tensorflow/contrib/batching/basic_batch_scheduler_benchmark.cc       435
-rw-r--r--  tensorflow/contrib/batching/basic_batch_scheduler_test.cc             91
-rw-r--r--  tensorflow/contrib/batching/batch_scheduler.h                         262
-rw-r--r--  tensorflow/contrib/batching/batch_scheduler_test.cc                   118
-rw-r--r--  tensorflow/contrib/batching/shared_batch_scheduler.h                  686
-rw-r--r--  tensorflow/contrib/batching/shared_batch_scheduler_test.cc            597
-rw-r--r--  tensorflow/contrib/batching/test_util/BUILD                             4
-rw-r--r--  tensorflow/contrib/batching/test_util/fake_clock_env.cc                90
-rw-r--r--  tensorflow/contrib/batching/test_util/fake_clock_env.h                 57
-rw-r--r--  tensorflow/contrib/batching/util/BUILD                                 17
-rw-r--r--  tensorflow/contrib/batching/util/periodic_function.cc                 102
-rw-r--r--  tensorflow/contrib/batching/util/periodic_function.h                  114
-rw-r--r--  tensorflow/contrib/batching/util/periodic_function_test.cc            225
17 files changed, 15 insertions, 4306 deletions
diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD
index ea8ac2c680..cd98f0e703 100644
--- a/tensorflow/contrib/batching/BUILD
+++ b/tensorflow/contrib/batching/BUILD
@@ -12,7 +12,7 @@ cc_library(
name = "batch_scheduler_hdrs",
hdrs = ["batch_scheduler.h"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core/kernels/batching_util:batch_scheduler_hdrs",
],
)
@@ -20,18 +20,7 @@ cc_library(
name = "batch_scheduler",
hdrs = ["batch_scheduler.h"],
deps = [
- "//tensorflow/core:lib",
- ],
-)
-
-tf_cc_test(
- name = "batch_scheduler_test",
- srcs = ["batch_scheduler_test.cc"],
- deps = [
- ":batch_scheduler",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
+ "//tensorflow/core/kernels/batching_util:batch_scheduler",
],
)
@@ -39,9 +28,7 @@ cc_library(
name = "shared_batch_scheduler_hdrs",
hdrs = ["shared_batch_scheduler.h"],
deps = [
- ":batch_scheduler_hdrs",
- "//tensorflow/contrib/batching/util:periodic_function_dynamic",
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core/kernels/batching_util:shared_batch_scheduler_hdrs",
],
)
@@ -49,49 +36,16 @@ cc_library(
name = "shared_batch_scheduler",
hdrs = ["shared_batch_scheduler.h"],
deps = [
- ":batch_scheduler",
- "//tensorflow/contrib/batching/util:periodic_function_dynamic",
- "//tensorflow/core:lib",
+ "//tensorflow/core/kernels/batching_util:shared_batch_scheduler",
],
alwayslink = 1,
)
-tf_cc_test(
- name = "shared_batch_scheduler_test",
- srcs = ["shared_batch_scheduler_test.cc"],
- deps = [
- ":shared_batch_scheduler",
- "//tensorflow/contrib/batching/test_util:fake_clock_env",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- ],
-)
-
cc_library(
name = "adaptive_shared_batch_scheduler",
hdrs = ["adaptive_shared_batch_scheduler.h"],
deps = [
- ":batch_scheduler",
- "//tensorflow/contrib/batching/util:periodic_function_dynamic",
- "//tensorflow/core:lib",
- ],
-)
-
-tf_cc_test(
- name = "adaptive_shared_batch_scheduler_test",
- srcs = ["adaptive_shared_batch_scheduler_test.cc"],
- tags = [
- "local",
- "manual",
- ],
- deps = [
- ":adaptive_shared_batch_scheduler",
- "//tensorflow/contrib/batching/test_util:fake_clock_env",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
+ "//tensorflow/core/kernels/batching_util:adaptive_shared_batch_scheduler",
],
)
@@ -99,34 +53,7 @@ cc_library(
name = "basic_batch_scheduler",
hdrs = ["basic_batch_scheduler.h"],
deps = [
- ":shared_batch_scheduler",
- ],
-)
-
-tf_cc_test(
- name = "basic_batch_scheduler_test",
- srcs = ["basic_batch_scheduler_test.cc"],
- deps = [
- ":basic_batch_scheduler",
- ":batch_scheduler",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- ],
-)
-
-tf_cc_test(
- name = "basic_batch_scheduler_benchmark",
- srcs = ["basic_batch_scheduler_benchmark.cc"],
- tags = [
- "local",
- "manual",
- ],
- deps = [
- ":basic_batch_scheduler",
- "//tensorflow/core:lib",
- "//tensorflow/core:tensorflow",
- "//tensorflow/core:test",
+ "//tensorflow/core/kernels/batching_util:basic_batch_scheduler",
],
)
diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h
index 6773accc6f..60861f83f4 100644
--- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h
+++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h
@@ -16,644 +16,6 @@ limitations under the License.
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
-#include <algorithm>
-#include <functional>
-#include <memory>
-#include <queue>
-#include <random>
-#include <unordered_map>
-#include <vector>
-
-#include "tensorflow/contrib/batching/batch_scheduler.h"
-#include "tensorflow/contrib/batching/util/periodic_function.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/threadpool.h"
-#include "tensorflow/core/platform/cpu_info.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/thread_annotations.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-namespace serving {
-namespace internal {
-template <typename TaskType>
-class ASBSBatch;
-
-template <typename TaskType>
-class ASBSQueue;
-} // namespace internal
-
-// EXPERIMENTAL: API MAY BE SUBJECTED TO SUDDEN CHANGES.
-//
-// Shared batch scheduler designed to minimize latency. The scheduler keeps
-// track of a number of queues (one per model or model version) which are
-// continuously enqueuing requests. The scheduler groups the requests into
-// batches which it periodically sends off for processing (see
-// shared_batch_scheduler.h for more details). The AdaptiveSharedBatchScheduler
-// prioritizes batches by age (i.e. the batch's oldest request) irrespective of
-// queue or batch size.
-//
-// The scheduling decision currently exists in two flavors, controlled by the
-// option use_in_flight_batches_implementation. It is expected that setting this
-// option to true will give universally better results; after a period of
-// testing to confirm, the old implementation will be removed.
-//
-// If use_in_flight_batches_implementation is set to true, the scheduler
-// limits the number of batches which can be processed concurrently. If a new
-// batch is created, and the number of in flight batches is below the limit,
-// the next (i.e. oldest) batch is immediately scheduled. Similarly, when a
-// batch finishes processing, the limit is rechecked, and another batch may be
-// scheduled. To avoid the need to carefully tune the limit for workload,
-// model type, platform, etc, it is dynamically adjusted in order to provide the
-// lowest latency.
-//
-// If use_in_flight_batches_implementation is set to false, the scheduler will
-// process the oldest batch at an adjustable rate, regardless of batch size.
-// The user can provide feedback to help set this rate to achieve some goal
-// (i.e. minimize overall latency, limit cpu usage, etc). The rate (or rather,
-// the corresponding period) is adjusted each time a batch is processed, using
-// an exponentially weighted moving average to smooth noisy feedback:
-// ewma_feedback = ((N - 1) * ewma_feedback + feedback()) / N
-// period *= (1 + K * ewma_feedback)
-//
-// Some potential use cases:
-// Hardware Accelerators (GPUs & TPUs) - If some phase of batch processing
-// involves serial processing by a device, from a latency perspective it is
-// desirable to keep the device evenly loaded, avoiding the need to wait for
-// the device to process prior batches.
-// feedback = num_pending_on_device() - desired_pending.
-// CPU utilization - If the batch processing is cpu dominated, you can reap
-// latency gains when underutilized by increasing the processing rate, but
-// back the rate off when the load increases to avoid overload.
-// feedback = cpu_rate() - desired_cpu_rate.
-
-template <typename TaskType>
-class AdaptiveSharedBatchScheduler
- : public std::enable_shared_from_this<
- AdaptiveSharedBatchScheduler<TaskType>> {
- public:
- ~AdaptiveSharedBatchScheduler() {
-    // Finish processing batches before destroying other class members.
- batch_thread_pool_.reset();
- }
-
- struct Options {
- // The name to use for the pool of batch threads.
- string thread_pool_name = {"batch_threads"};
- // Number of batch processing threads; equivalently the maximum number of
- // concurrently running batches.
- int64 num_batch_threads = port::NumSchedulableCPUs();
- // The environment to use (typically only overridden by test code).
- Env* env = Env::Default();
- // Which implementation to use (described in class comments above).
- bool use_in_flight_batches_implementation = false;
- // Initial limit for number of batches being concurrently processed.
- // Non-integer values correspond to probabilistic limits - i.e. a value of
- // 3.2 results in an actual cap of 3 80% of the time, and 4 20% of the time.
- double initial_in_flight_batches_limit = 3;
- // Number of batches between adjustments of in_flight_batches_limit. Larger
- // numbers will give less noisy latency measurements, but will be less
- // responsive to changes in workload.
- int64 batches_to_average_over = 1000;
-
- // TODO(kte): remove the rate based implementation and corresponding options
- // below once testing confirms the superiority of the in flight batches
- // implementation.
- // Initial batch scheduling period in microseconds. Will be altered for
- // non-zero rate_feedback.
- double initial_scheduling_period_micros = 500;
- // Minimum batch scheduling period in microseconds. Recommend setting this
- // value greater than 0, otherwise it may take a while to recover from a
- // sustained time of negative scheduling_period_feedback (which may occur
- // under low load).
- double min_scheduling_period_micros = 100;
- // Maximum batch scheduling period in microseconds.
- double max_scheduling_period_micros = 10000;
- // Feedback function used to modify the scheduling period each time a batch
- // is scheduled. Should return values roughly O(1), with positive values
- // resulting in an increased period.
- std::function<double()> scheduling_period_feedback{[] { return 0.; }};
- // To handle potentially noisy scheduling_period_feedback, the period is
- // adjusted using an exponentially weighted moving average over the previous
- // feedback_smoothing_batches batches. Must be greater than 0.
- int64 feedback_smoothing_batches = 10;
- };
-
- // Ownership is shared between the caller of Create() and any queues created
- // via AddQueue().
- static Status Create(
- const Options& options,
- std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler);
-
- struct QueueOptions {
- // Maximum size of each batch.
- int max_batch_size = 1000;
- // Maximum number of enqueued (i.e. non-scheduled) batches.
- int max_enqueued_batches = 10;
- };
-
- using BatchProcessor = std::function<void(std::unique_ptr<Batch<TaskType>>)>;
-
- // Adds queue (and its callback) to be managed by this scheduler.
- Status AddQueue(const QueueOptions& options,
- BatchProcessor process_batch_callback,
- std::unique_ptr<BatchScheduler<TaskType>>* queue);
-
- double in_flight_batches_limit() {
- mutex_lock l(mu_);
- return in_flight_batches_limit_;
- }
-
- private:
- // access to AddBatch, RemoveQueue, GetEnv.
- friend class internal::ASBSQueue<TaskType>;
-
- explicit AdaptiveSharedBatchScheduler(const Options& options);
-
- // Batch scheduling function which runs every scheduling_period_ microseconds.
- // Only used when options_.use_in_flight_batches_implementation == false.
- void ProcessOneBatch();
-
- // Tracks processing latency and adjusts in_flight_batches_limit to minimize.
- // Only used when options_.use_in_flight_batches_implementation == true.
- void CallbackWrapper(const internal::ASBSBatch<TaskType>* batch,
- BatchProcessor callback);
-
- // Schedules batch if in_flight_batches_limit_ is not met.
- // Only used when options_.use_in_flight_batches_implementation == true.
- void MaybeScheduleNextBatch() EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
- // Notifies scheduler of non-empty batch which is eligible for processing.
- void AddBatch(const internal::ASBSBatch<TaskType>* batch);
-
- // Removes queue from scheduler.
- void RemoveQueue(const internal::ASBSQueue<TaskType>* queue);
-
- Env* GetEnv() const { return options_.env; }
-
- const Options options_;
-
- struct BatchCompare {
- bool operator()(const internal::ASBSBatch<TaskType>* a,
- const internal::ASBSBatch<TaskType>* b);
- };
-
- // Collection of batches added by AddBatch, ordered by age. Owned by scheduler
- // until they are released for processing.
- std::priority_queue<const internal::ASBSBatch<TaskType>*,
- std::vector<const internal::ASBSBatch<TaskType>*>,
- BatchCompare>
- batches_ GUARDED_BY(mu_);
-
- // Unowned queues and callbacks added by AddQueue.
- std::unordered_map<const internal::ASBSQueue<TaskType>*, BatchProcessor>
- queues_and_callbacks_ GUARDED_BY(mu_);
-
- mutex mu_;
-
- // Responsible for running ProcessOneBatch. PeriodicFunction was used in order
- // to check for deletion so that the thread can be shut down.
- // Only used when options_.use_in_flight_batches_implementation == false.
- std::unique_ptr<PeriodicFunction> scheduling_thread_;
-
- // Responsible for running the batch processing callbacks.
- std::unique_ptr<thread::ThreadPool> batch_thread_pool_;
-
- // Time interval in microseconds between successive ProcessOneBatch calls.
- // Only used when options_.use_in_flight_batches_implementation == false.
- double scheduling_period_;
-
- // Exponentially weighted moving average of
- // options_.scheduling_period_feedback() evaluated in each ProcessOneBatch
- // call.
- // Only used when options_.use_in_flight_batches_implementation == false.
- double ewma_feedback_ = 0;
-
- // Limit on number of batches which can be concurrently processed.
- // Non-integer values correspond to probabilistic limits - i.e. a value of 3.2
- // results in an actual cap of 3 80% of the time, and 4 20% of the time.
- // Only used when options_.use_in_flight_batches_implementation == true.
- double in_flight_batches_limit_ GUARDED_BY(mu_);
-
- // Number of batches currently being processed.
- // Only used when options_.use_in_flight_batches_implementation == true.
- int64 in_flight_batches_ GUARDED_BY(mu_) = 0;
-
- // RNG engine and distribution.
- // Only used when options_.use_in_flight_batches_implementation == true.
- std::default_random_engine rand_engine_;
- std::uniform_real_distribution<double> rand_double_;
-
- // Fields controlling the dynamic adjustment of in_flight_batches_limit_.
- // Only used when options_.use_in_flight_batches_implementation == true.
- // Number of batches since the last in_flight_batches_limit_ adjustment.
- int64 batch_count_ GUARDED_BY(mu_) = 0;
- // Sum of processing latency for batches counted by batch_count_.
- int64 batch_latency_sum_ GUARDED_BY(mu_) = 0;
- // Average batch latency for previous value of in_flight_batches_limit_.
- double last_avg_latency_ms_ GUARDED_BY(mu_) = 0;
- // Did last_avg_latency_ms_ decrease from the previous last_avg_latency_ms_?
- bool last_latency_decreased_ GUARDED_BY(mu_) = false;
- // Current direction (+-) to adjust in_flight_batches_limit_
- int step_direction_ GUARDED_BY(mu_) = 1;
- // Max adjustment size (as a fraction of in_flight_batches_limit_).
- constexpr static double kMaxStepSizeMultiplier = 0.125; // 1/8;
- // Min adjustment size (as a fraction of in_flight_batches_limit_).
- constexpr static double kMinStepSizeMultiplier = 0.0078125; // 1/128
- // Current adjustment size (as a fraction of in_flight_batches_limit_).
- double step_size_multiplier_ GUARDED_BY(mu_) = kMaxStepSizeMultiplier;
-
- TF_DISALLOW_COPY_AND_ASSIGN(AdaptiveSharedBatchScheduler);
-};
-
-//////////////////////////////////////////////////////////
-// Implementation details follow. API users need not read.
-
-namespace internal {
-// Consolidates tasks into batches, passing them off to the
-// AdaptiveSharedBatchScheduler for processing.
-template <typename TaskType>
-class ASBSQueue : public BatchScheduler<TaskType> {
- public:
- using QueueOptions =
- typename AdaptiveSharedBatchScheduler<TaskType>::QueueOptions;
-
- ASBSQueue(std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
- const QueueOptions& options);
-
- ~ASBSQueue() override;
-
- // Adds task to current batch. Fails if the task size is larger than the batch
- // size or if the current batch is full and this queue's number of outstanding
- // batches is at its maximum.
- Status Schedule(std::unique_ptr<TaskType>* task) override;
-
- // Number of tasks waiting to be scheduled.
- size_t NumEnqueuedTasks() const override;
-
- // Number of size 1 tasks which could currently be scheduled without failing.
- size_t SchedulingCapacity() const override;
-
- // Notifies queue that a batch is about to be scheduled; the queue should not
- // place any more tasks in this batch.
- void ReleaseBatch(const ASBSBatch<TaskType>* batch);
-
- size_t max_task_size() const override { return options_.max_batch_size; }
-
- private:
- std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler_;
- const QueueOptions options_;
- // Owned by scheduler_.
- ASBSBatch<TaskType>* current_batch_ GUARDED_BY(mu_) = nullptr;
- int64 num_enqueued_batches_ GUARDED_BY(mu_) = 0;
- int64 num_enqueued_tasks_ GUARDED_BY(mu_) = 0;
- mutable mutex mu_;
- TF_DISALLOW_COPY_AND_ASSIGN(ASBSQueue);
-};
-
-// Batch which remembers when and by whom it was created.
-template <typename TaskType>
-class ASBSBatch : public Batch<TaskType> {
- public:
- ASBSBatch(ASBSQueue<TaskType>* queue, int64 creation_time_micros)
- : queue_(queue), creation_time_micros_(creation_time_micros) {}
-
- ~ASBSBatch() override {}
-
- ASBSQueue<TaskType>* queue() const { return queue_; }
-
- int64 creation_time_micros() const { return creation_time_micros_; }
-
- private:
- ASBSQueue<TaskType>* queue_;
- const int64 creation_time_micros_;
- TF_DISALLOW_COPY_AND_ASSIGN(ASBSBatch);
-};
-} // namespace internal
-
-// ---------------- AdaptiveSharedBatchScheduler ----------------
-
-template <typename TaskType>
-constexpr double AdaptiveSharedBatchScheduler<TaskType>::kMaxStepSizeMultiplier;
-
-template <typename TaskType>
-constexpr double AdaptiveSharedBatchScheduler<TaskType>::kMinStepSizeMultiplier;
-
-template <typename TaskType>
-Status AdaptiveSharedBatchScheduler<TaskType>::Create(
- const Options& options,
- std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler) {
- if (options.num_batch_threads < 1) {
- return errors::InvalidArgument("num_batch_threads must be positive; was ",
- options.num_batch_threads);
- }
- if (options.min_scheduling_period_micros < 0) {
- return errors::InvalidArgument(
- "min_scheduling_period_micros must be >= 0; was ",
- options.min_scheduling_period_micros);
- }
- if (options.min_scheduling_period_micros >
- options.initial_scheduling_period_micros) {
- return errors::InvalidArgument(
- "initial_scheduling_period_micros (",
- options.initial_scheduling_period_micros,
- ") must be >= min_scheduling_period_micros (",
- options.min_scheduling_period_micros, ")");
- }
- if (options.initial_scheduling_period_micros >
- options.max_scheduling_period_micros) {
- return errors::InvalidArgument(
- "initial_scheduling_period_micros (",
- options.initial_scheduling_period_micros,
- ") must be <= max_scheduling_period_micros (",
- options.max_scheduling_period_micros, ")");
- }
- if (options.feedback_smoothing_batches < 1) {
- return errors::InvalidArgument(
- "feedback_smoothing_batches must be positive; was ",
- options.feedback_smoothing_batches);
- }
- if (options.initial_in_flight_batches_limit > options.num_batch_threads) {
- return errors::InvalidArgument(
- "initial_in_flight_batches_limit (",
- options.initial_in_flight_batches_limit,
- ") should not be larger than num_batch_threads (",
- options.num_batch_threads, ")");
- }
- if (options.initial_in_flight_batches_limit < 1) {
- return errors::InvalidArgument(
- "initial_in_flight_batches_limit should be "
- "greater than or equal to 1; was ",
- options.initial_in_flight_batches_limit);
- }
- if (options.batches_to_average_over < 1) {
- return errors::InvalidArgument(
- "batches_to_average_over should be "
- "greater than or equal to 1; was ",
- options.batches_to_average_over);
- }
- scheduler->reset(new AdaptiveSharedBatchScheduler<TaskType>(options));
- return Status::OK();
-}
-
-template <typename TaskType>
-AdaptiveSharedBatchScheduler<TaskType>::AdaptiveSharedBatchScheduler(
- const Options& options)
- : options_(options),
- scheduling_period_(options.initial_scheduling_period_micros),
- in_flight_batches_limit_(options.initial_in_flight_batches_limit),
- rand_double_(0.0, 1.0) {
- std::random_device device;
- rand_engine_.seed(device());
- PeriodicFunction::Options opts;
- opts.thread_name_prefix = "scheduling_thread";
- opts.env = GetEnv();
- batch_thread_pool_.reset(new thread::ThreadPool(
- GetEnv(), options.thread_pool_name, options.num_batch_threads));
- if (!options.use_in_flight_batches_implementation) {
- scheduling_thread_.reset(
- new PeriodicFunction([this] { ProcessOneBatch(); }, 0, opts));
- }
-}
-
-template <typename TaskType>
-Status AdaptiveSharedBatchScheduler<TaskType>::AddQueue(
- const QueueOptions& options, BatchProcessor process_batch_callback,
- std::unique_ptr<BatchScheduler<TaskType>>* queue) {
- if (options.max_batch_size <= 0) {
- return errors::InvalidArgument("max_batch_size must be positive; was ",
- options.max_batch_size);
- }
- if (options.max_enqueued_batches <= 0) {
- return errors::InvalidArgument(
- "max_enqueued_batches must be positive; was ",
- options.max_enqueued_batches);
- }
- internal::ASBSQueue<TaskType>* asbs_queue_raw;
- queue->reset(asbs_queue_raw = new internal::ASBSQueue<TaskType>(
- this->shared_from_this(), options));
- mutex_lock l(mu_);
- queues_and_callbacks_[asbs_queue_raw] = process_batch_callback;
- return Status::OK();
-}
-
-template <typename TaskType>
-void AdaptiveSharedBatchScheduler<TaskType>::AddBatch(
- const internal::ASBSBatch<TaskType>* batch) {
- mutex_lock l(mu_);
- batches_.push(batch);
- if (options_.use_in_flight_batches_implementation) {
- MaybeScheduleNextBatch();
- }
-}
-
-template <typename TaskType>
-void AdaptiveSharedBatchScheduler<TaskType>::RemoveQueue(
- const internal::ASBSQueue<TaskType>* queue) {
- mutex_lock l(mu_);
- queues_and_callbacks_.erase(queue);
-}
-
-template <typename TaskType>
-void AdaptiveSharedBatchScheduler<TaskType>::MaybeScheduleNextBatch() {
- if (batches_.empty() || in_flight_batches_ >= in_flight_batches_limit_)
- return;
-  // Non-integer limit handled probabilistically.
- if (in_flight_batches_limit_ - in_flight_batches_ < 1 &&
- rand_double_(rand_engine_) >
- (in_flight_batches_limit_ - in_flight_batches_))
- return;
- const internal::ASBSBatch<TaskType>* batch = batches_.top();
- batches_.pop();
- // Queue may destroy itself after ReleaseBatch is called.
- batch->queue()->ReleaseBatch(batch);
- batch_thread_pool_->Schedule(
- std::bind(&AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper, this,
- batch, queues_and_callbacks_[batch->queue()]));
- in_flight_batches_++;
-}
-
-template <typename TaskType>
-void AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper(
- const internal::ASBSBatch<TaskType>* batch,
- AdaptiveSharedBatchScheduler<TaskType>::BatchProcessor callback) {
- int64 start_time = batch->creation_time_micros();
- callback(std::unique_ptr<Batch<TaskType>>(
- const_cast<internal::ASBSBatch<TaskType>*>(batch)));
- int64 end_time = GetEnv()->NowMicros();
- mutex_lock l(mu_);
- in_flight_batches_--;
- batch_count_++;
- batch_latency_sum_ += end_time - start_time;
- // Occasionally adjust in_flight_batches_limit_ to minimize average latency.
- // Although the optimal value may depend on the workload, the latency should
- // be a simple convex function of in_flight_batches_limit_, allowing us to
- // locate the global minimum relatively quickly.
- if (batch_count_ == options_.batches_to_average_over) {
- double current_avg_latency_ms = (batch_latency_sum_ / 1000.) / batch_count_;
- bool current_latency_decreased =
- current_avg_latency_ms < last_avg_latency_ms_;
- if (current_latency_decreased) {
- // If latency improvement was because we're moving in the correct
- // direction, increase step_size so that we can get to the minimum faster.
- // If latency improvement was due to backtracking from a previous failure,
- // decrease step_size in order to refine our location.
- step_size_multiplier_ *= (last_latency_decreased_ ? 2 : 0.5);
- step_size_multiplier_ =
- std::min(step_size_multiplier_, kMaxStepSizeMultiplier);
- step_size_multiplier_ =
- std::max(step_size_multiplier_, kMinStepSizeMultiplier);
- } else {
- // Return (nearly) to previous position and confirm that latency is better
- // there before decreasing step size.
- step_direction_ = -step_direction_;
- }
- in_flight_batches_limit_ +=
- step_direction_ * in_flight_batches_limit_ * step_size_multiplier_;
- in_flight_batches_limit_ =
- std::min(in_flight_batches_limit_,
- static_cast<double>(options_.num_batch_threads));
- in_flight_batches_limit_ = std::max(in_flight_batches_limit_, 1.0);
- last_avg_latency_ms_ = current_avg_latency_ms;
- last_latency_decreased_ = current_latency_decreased;
- batch_count_ = 0;
- batch_latency_sum_ = 0;
- }
- MaybeScheduleNextBatch();
-}
-
-template <typename TaskType>
-void AdaptiveSharedBatchScheduler<TaskType>::ProcessOneBatch() {
- static const double kFeedbackMultiplier = .001;
- const internal::ASBSBatch<TaskType>* batch = nullptr;
- BatchProcessor callback;
- const int64 start_time_micros = GetEnv()->NowMicros();
- {
- mutex_lock l(mu_);
- if (!batches_.empty()) {
- batch = batches_.top();
- batches_.pop();
- callback = queues_and_callbacks_[batch->queue()];
- }
- }
- if (batch != nullptr) {
- double feedback = options_.scheduling_period_feedback();
- const int64 N = options_.feedback_smoothing_batches;
- ewma_feedback_ = ((N - 1) * ewma_feedback_ + feedback) / N;
- scheduling_period_ *= (1 + kFeedbackMultiplier * ewma_feedback_);
- if (scheduling_period_ < options_.min_scheduling_period_micros) {
- scheduling_period_ = options_.min_scheduling_period_micros;
- } else if (scheduling_period_ > options_.max_scheduling_period_micros) {
- scheduling_period_ = options_.max_scheduling_period_micros;
- }
- // Queue may destroy itself after ReleaseBatch is called.
- batch->queue()->ReleaseBatch(batch);
- batch_thread_pool_->Schedule([callback, batch] {
- callback(std::unique_ptr<Batch<TaskType>>(
- const_cast<internal::ASBSBatch<TaskType>*>(batch)));
- });
- }
- const int64 sleep_time =
- scheduling_period_ - (GetEnv()->NowMicros() - start_time_micros);
- if (sleep_time > 0) {
- GetEnv()->SleepForMicroseconds(sleep_time);
- }
-}
-
-template <typename TaskType>
-bool AdaptiveSharedBatchScheduler<TaskType>::BatchCompare::operator()(
- const internal::ASBSBatch<TaskType>* a,
- const internal::ASBSBatch<TaskType>* b) {
- return a->creation_time_micros() > b->creation_time_micros();
-}
-
-// ---------------- ASBSQueue ----------------
-
-namespace internal {
-template <typename TaskType>
-ASBSQueue<TaskType>::ASBSQueue(
- std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
- const QueueOptions& options)
- : scheduler_(scheduler), options_(options) {}
-
-template <typename TaskType>
-ASBSQueue<TaskType>::~ASBSQueue() {
- // Wait until last batch has been scheduled.
- const int kSleepMicros = 1000;
- for (;;) {
- {
- mutex_lock l(mu_);
- if (num_enqueued_batches_ == 0) {
- break;
- }
- }
- scheduler_->GetEnv()->SleepForMicroseconds(kSleepMicros);
- }
- scheduler_->RemoveQueue(this);
-}
-
-template <typename TaskType>
-Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
- ASBSBatch<TaskType>* new_batch = nullptr;
- size_t size = (*task)->size();
- if (size > options_.max_batch_size) {
- return errors::InvalidArgument("Task size ", size,
- " is larger than maximum batch size ",
- options_.max_batch_size);
- }
- {
- mutex_lock l(mu_);
- // Current batch is full, create another if allowed.
- if (current_batch_ &&
- current_batch_->size() + size > options_.max_batch_size) {
- if (num_enqueued_batches_ >= options_.max_enqueued_batches) {
- return errors::Unavailable("The batch scheduling queue is full");
- }
- current_batch_->Close();
- current_batch_ = nullptr;
- }
- if (!current_batch_) {
- num_enqueued_batches_++;
- current_batch_ = new_batch =
- new ASBSBatch<TaskType>(this, scheduler_->GetEnv()->NowMicros());
- }
- current_batch_->AddTask(std::move(*task));
- num_enqueued_tasks_++;
- }
- // AddBatch must be called outside of lock, since it may call ReleaseBatch.
- if (new_batch != nullptr) scheduler_->AddBatch(new_batch);
- return Status::OK();
-}
-
-template <typename TaskType>
-void ASBSQueue<TaskType>::ReleaseBatch(const ASBSBatch<TaskType>* batch) {
- mutex_lock l(mu_);
- num_enqueued_batches_--;
- num_enqueued_tasks_ -= batch->num_tasks();
- if (batch == current_batch_) {
- current_batch_->Close();
- current_batch_ = nullptr;
- }
-}
-
-template <typename TaskType>
-size_t ASBSQueue<TaskType>::NumEnqueuedTasks() const {
- mutex_lock l(mu_);
- return num_enqueued_tasks_;
-}
-
-template <typename TaskType>
-size_t ASBSQueue<TaskType>::SchedulingCapacity() const {
- mutex_lock l(mu_);
- const int current_batch_capacity =
- current_batch_ ? options_.max_batch_size - current_batch_->size() : 0;
- const int spare_batches =
- options_.max_enqueued_batches - num_enqueued_batches_;
- return spare_batches * options_.max_batch_size + current_batch_capacity;
-}
-} // namespace internal
-} // namespace serving
-} // namespace tensorflow
+#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
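
Since the shim above preserves the old include path, existing callers need no source changes. Below is a minimal usage sketch, not taken from this commit: MyTask and BuildScheduler are hypothetical names, while the BatchTask/Batch/BatchScheduler types and the Create()/AddQueue() signatures are the ones declared in the deleted header above, now provided via the forwarded header.

#include <stddef.h>
#include <memory>

// Old contrib path; after this change it simply forwards to
// tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h.
#include "tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h"
#include "tensorflow/core/lib/core/errors.h"

namespace serving = tensorflow::serving;

// Hypothetical task type: size() reports how many batch slots the task uses.
class MyTask : public serving::BatchTask {
 public:
  explicit MyTask(size_t size) : size_(size) {}
  size_t size() const override { return size_; }

 private:
  const size_t size_;
};

// Hypothetical helper: creates a scheduler with default options and one queue
// whose callback receives closed batches of MyTask for processing.
tensorflow::Status BuildScheduler(
    std::unique_ptr<serving::BatchScheduler<MyTask>>* queue) {
  std::shared_ptr<serving::AdaptiveSharedBatchScheduler<MyTask>> scheduler;
  TF_RETURN_IF_ERROR(
      serving::AdaptiveSharedBatchScheduler<MyTask>::Create({}, &scheduler));
  return scheduler->AddQueue(
      {},
      [](std::unique_ptr<serving::Batch<MyTask>> batch) {
        // Each batch arrives closed; process batch->task(i) here.
      },
      queue);
}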
diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc
deleted file mode 100644
index 68ee277327..0000000000
--- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc
+++ /dev/null
@@ -1,549 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h"
-
-#include "tensorflow/contrib/batching/test_util/fake_clock_env.h"
-#include "tensorflow/core/lib/core/notification.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace serving {
-namespace anonymous {
-
-class FakeTask : public BatchTask {
- public:
- explicit FakeTask(size_t size) : size_(size) {}
-
- ~FakeTask() override = default;
-
- size_t size() const override { return size_; }
-
- private:
- const size_t size_;
-
- TF_DISALLOW_COPY_AND_ASSIGN(FakeTask);
-};
-
-// Creates a FakeTask of size 'task_size', and calls 'scheduler->Schedule()' on
-// that task. Returns the resulting status.
-Status ScheduleTask(size_t task_size, BatchScheduler<FakeTask>* scheduler) {
- std::unique_ptr<FakeTask> task(new FakeTask(task_size));
- Status status = scheduler->Schedule(&task);
- // Schedule() should have consumed 'task' iff it returned Status::OK.
- CHECK_EQ(status.ok(), task == nullptr);
- return status;
-}
-
-// Creates a thread that waits on 'start' and then advances the fake clock in
-// 'env' in a loop until 'stop' is notified. Useful for allowing objects that
-// use the clock to be destroyed.
-std::unique_ptr<Thread> CreateFakeClockAdvancerThread(
- test_util::FakeClockEnv* env, Notification* start, Notification* stop) {
- return std::unique_ptr<Thread>(Env::Default()->StartThread(
- {}, "FakeClockAdvancerThread", [env, start, stop] {
- start->WaitForNotification();
- while (!stop->HasBeenNotified()) {
- env->AdvanceByMicroseconds(10);
- Env::Default()->SleepForMicroseconds(10);
- }
- }));
-}
-
-TEST(AdaptiveSharedBatchSchedulerTest, Basic) {
- for (const bool delete_scheduler_early : {false, true}) {
- for (const bool delete_queue_1_early : {false, true}) {
- int queue_0_tasks = 0;
- auto queue_0_callback =
- [&queue_0_tasks](std::unique_ptr<Batch<FakeTask>> batch) {
- ASSERT_TRUE(batch->IsClosed());
- EXPECT_GT(batch->num_tasks(), 0);
- for (int i = 0; i < batch->num_tasks(); i++) {
- queue_0_tasks += batch->task(i).size();
- }
- };
- int queue_1_tasks = 0;
- auto queue_1_callback =
- [&queue_1_tasks](std::unique_ptr<Batch<FakeTask>> batch) {
- ASSERT_TRUE(batch->IsClosed());
- EXPECT_GT(batch->num_tasks(), 0);
- for (int i = 0; i < batch->num_tasks(); i++) {
- queue_1_tasks += batch->task(i).size();
- }
- };
- {
- std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
- TF_ASSERT_OK(
- AdaptiveSharedBatchScheduler<FakeTask>::Create({}, &scheduler));
-
- // Create two queues.
- std::unique_ptr<BatchScheduler<FakeTask>> queue_0;
- TF_ASSERT_OK(scheduler->AddQueue({}, queue_0_callback, &queue_0));
- std::unique_ptr<BatchScheduler<FakeTask>> queue_1;
- TF_ASSERT_OK(scheduler->AddQueue({}, queue_1_callback, &queue_1));
-
- if (delete_scheduler_early) {
- // Delete our copy of the scheduler. The queues should keep it alive
- // under the covers.
- scheduler = nullptr;
- }
- // Submit tasks to the two queues, and (optionally) remove the queues.
- TF_ASSERT_OK(ScheduleTask(1, queue_0.get()));
- TF_ASSERT_OK(ScheduleTask(2, queue_1.get()));
- TF_ASSERT_OK(ScheduleTask(3, queue_0.get()));
- TF_ASSERT_OK(ScheduleTask(4, queue_1.get()));
- if (delete_queue_1_early) {
- queue_1 = nullptr;
- }
- TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
- }
- EXPECT_EQ(queue_0_tasks, 9);
- EXPECT_EQ(queue_1_tasks, 6);
- }
- }
-}
-
-TEST(AdaptiveSharedBatchSchedulerTest, BadOptions) {
- using Scheduler = AdaptiveSharedBatchScheduler<FakeTask>;
- std::shared_ptr<Scheduler> scheduler;
- Scheduler::Options options;
- options.num_batch_threads = 0;
- EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
- options = Scheduler::Options();
- options.min_scheduling_period_micros = 50;
- options.max_scheduling_period_micros = 100;
- options.initial_scheduling_period_micros = 1;
- EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
- options = Scheduler::Options();
- options.min_scheduling_period_micros = 50;
- options.max_scheduling_period_micros = 100;
- options.initial_scheduling_period_micros = 1000;
- EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
- options = Scheduler::Options();
- options.min_scheduling_period_micros = 100;
- options.max_scheduling_period_micros = 50;
- options.initial_scheduling_period_micros = 75;
- EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
- options = Scheduler::Options();
- options.feedback_smoothing_batches = 0;
- EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
- options = Scheduler::Options();
- options.initial_in_flight_batches_limit = 0.5;
- EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
- options = Scheduler::Options();
- options.num_batch_threads = 5;
- options.initial_in_flight_batches_limit = 8;
- EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
- options = Scheduler::Options();
- options.batches_to_average_over = -5;
- EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
-}
-
-TEST(AdaptiveSharedBatchSchedulerTest, ObeysQueueOptions) {
- test_util::FakeClockEnv env(Env::Default());
- Notification start_teardown, stop_teardown;
- std::unique_ptr<Thread> teardown_thread =
- CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
- {
- AdaptiveSharedBatchScheduler<FakeTask>::Options options;
- options.initial_scheduling_period_micros = 1000;
- options.env = &env;
- std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
- TF_ASSERT_OK(
- AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
- std::unique_ptr<BatchScheduler<FakeTask>> queue_0;
- std::unique_ptr<BatchScheduler<FakeTask>> queue_1;
- int queue_0_tasks = 0;
- int queue_1_tasks = 0;
- auto queue_0_callback = [&queue_0_tasks,
- &env](std::unique_ptr<Batch<FakeTask>> batch) {
- ASSERT_TRUE(batch->IsClosed());
- EXPECT_GT(batch->num_tasks(), 0);
- for (int i = 0; i < batch->num_tasks(); i++) {
- queue_0_tasks += batch->task(i).size();
- }
- env.SleepForMicroseconds(1);
- };
- auto queue_1_callback = [&queue_1_tasks,
- &env](std::unique_ptr<Batch<FakeTask>> batch) {
- ASSERT_TRUE(batch->IsClosed());
- EXPECT_GT(batch->num_tasks(), 0);
- for (int i = 0; i < batch->num_tasks(); i++) {
- queue_1_tasks += batch->task(i).size();
- }
- env.SleepForMicroseconds(1);
- };
- AdaptiveSharedBatchScheduler<FakeTask>::QueueOptions queue_options;
- queue_options.max_batch_size = 10;
- queue_options.max_enqueued_batches = 0;
-    // Queue must have max_enqueued_batches > 0.
- EXPECT_FALSE(
- scheduler->AddQueue(queue_options, queue_0_callback, &queue_0).ok());
- queue_options.max_enqueued_batches = 2;
- TF_ASSERT_OK(
- scheduler->AddQueue(queue_options, queue_0_callback, &queue_0));
- EXPECT_EQ(10, queue_0->max_task_size());
- queue_options.max_batch_size = 0;
- // Queue must have max_batch_size > 0.
- EXPECT_FALSE(
- scheduler->AddQueue(queue_options, queue_1_callback, &queue_1).ok());
- queue_options.max_batch_size = 2;
- queue_options.max_enqueued_batches = 1;
- TF_ASSERT_OK(
- scheduler->AddQueue(queue_options, queue_1_callback, &queue_1));
-
- // Wait for scheduling_thread to sleep.
- env.BlockUntilThreadsAsleep(1);
- // Task larger than max_batch_size shouldn't schedule.
- EXPECT_FALSE(ScheduleTask(15, queue_0.get()).ok());
- TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
- TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
- env.AdvanceByMicroseconds(1);
-
- // Task larger than max_batch_size shouldn't schedule.
- EXPECT_FALSE(ScheduleTask(3, queue_1.get()).ok());
- TF_ASSERT_OK(ScheduleTask(1, queue_1.get()));
- TF_ASSERT_OK(ScheduleTask(1, queue_1.get()));
- env.AdvanceByMicroseconds(1);
- // Exceeds max_enqueued_batches, shouldn't schedule.
- EXPECT_FALSE(ScheduleTask(1, queue_1.get()).ok());
-
- TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
- // Exceeds max_enqueued_batches, shouldn't schedule.
- EXPECT_FALSE(ScheduleTask(6, queue_0.get()).ok());
- TF_ASSERT_OK(ScheduleTask(4, queue_0.get()));
-
- // Batches should be processed in order from oldest to newest.
- env.AdvanceByMicroseconds(1000);
- env.BlockUntilThreadsAsleep(2);
- EXPECT_EQ(queue_0_tasks, 10);
- EXPECT_EQ(queue_1_tasks, 0);
-
- env.AdvanceByMicroseconds(1000);
- env.BlockUntilThreadsAsleep(2);
- EXPECT_EQ(queue_0_tasks, 10);
- EXPECT_EQ(queue_1_tasks, 2);
-
- env.AdvanceByMicroseconds(1000);
- env.BlockUntilThreadsAsleep(2);
- EXPECT_EQ(queue_0_tasks, 19);
- EXPECT_EQ(queue_1_tasks, 2);
- start_teardown.Notify();
- }
- stop_teardown.Notify();
-}
-
-TEST(AdaptiveSharedBatchSchedulerTest, RateFeedback) {
- test_util::FakeClockEnv env(Env::Default());
- Notification start_teardown, stop_teardown;
- std::unique_ptr<Thread> teardown_thread =
- CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
- {
- double feedback = 0;
- AdaptiveSharedBatchScheduler<FakeTask>::Options options;
- options.initial_scheduling_period_micros = 1000;
- options.min_scheduling_period_micros = 200;
- options.max_scheduling_period_micros = 2000;
- options.env = &env;
- options.scheduling_period_feedback = [&feedback] { return feedback; };
- options.feedback_smoothing_batches = 1;
- std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
- TF_ASSERT_OK(
- AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
- std::unique_ptr<BatchScheduler<FakeTask>> queue;
- int scheduled_items = 0;
- auto queue_callback = [&scheduled_items,
- &env](std::unique_ptr<Batch<FakeTask>> batch) {
- ASSERT_TRUE(batch->IsClosed());
- EXPECT_GT(batch->num_tasks(), 0);
- scheduled_items = 0;
- for (int i = 0; i < batch->num_tasks(); i++) {
- scheduled_items += batch->task(i).size();
- }
- env.SleepForMicroseconds(1);
- };
-
- TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
-
- // Wait for scheduling_thread to sleep.
- env.BlockUntilThreadsAsleep(1);
- // Enqueue 6 batches.
- for (int i = 0; i < 6; i++) {
- TF_ASSERT_OK(ScheduleTask(900 + i, queue.get()));
- env.AdvanceByMicroseconds(1);
- }
- feedback = -500;
- env.AdvanceByMicroseconds(994);
- env.BlockUntilThreadsAsleep(2); // scheduling period = 500 usec.
- EXPECT_EQ(scheduled_items, 900);
- env.AdvanceByMicroseconds(500);
- env.BlockUntilThreadsAsleep(2); // scheduling period = 250 usec.
- EXPECT_EQ(scheduled_items, 901);
- feedback = 0;
- env.AdvanceByMicroseconds(250);
- env.BlockUntilThreadsAsleep(2); // scheduling period = 250 usec.
- EXPECT_EQ(scheduled_items, 902);
- feedback = 10000; // large feedback should hit max_scheduling_period.
- env.AdvanceByMicroseconds(250);
- env.BlockUntilThreadsAsleep(2); // scheduling period = 2000 usec.
- EXPECT_EQ(scheduled_items, 903);
- feedback = -10000; // large feedback should hit min_scheduling_period.
- env.AdvanceByMicroseconds(1999);
- // No callback scheduled, only scheduling thread sleeping.
- env.BlockUntilThreadsAsleep(1);
- EXPECT_EQ(scheduled_items, 903);
- env.AdvanceByMicroseconds(1);
- env.BlockUntilThreadsAsleep(2); // scheduling period = 200 usec.
- EXPECT_EQ(scheduled_items, 904);
- env.AdvanceByMicroseconds(200);
- env.BlockUntilThreadsAsleep(2);
- EXPECT_EQ(scheduled_items, 905);
- start_teardown.Notify();
- }
- stop_teardown.Notify();
-}
-
-TEST(AdaptiveSharedBatchSchedulerTest, FeedbackSmoothing) {
- test_util::FakeClockEnv env(Env::Default());
- Notification start_teardown, stop_teardown;
- std::unique_ptr<Thread> teardown_thread =
- CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
- {
- double feedback = 0;
- AdaptiveSharedBatchScheduler<FakeTask>::Options options;
- options.initial_scheduling_period_micros = 1000;
- options.env = &env;
- options.scheduling_period_feedback = [&feedback] { return feedback; };
- options.feedback_smoothing_batches = 3;
- std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
- TF_ASSERT_OK(
- AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
- std::unique_ptr<BatchScheduler<FakeTask>> queue;
- int scheduled_items = 0;
- auto queue_callback = [&scheduled_items,
- &env](std::unique_ptr<Batch<FakeTask>> batch) {
- ASSERT_TRUE(batch->IsClosed());
- EXPECT_GT(batch->num_tasks(), 0);
- scheduled_items = 0;
- for (int i = 0; i < batch->num_tasks(); i++) {
- scheduled_items += batch->task(i).size();
- }
- env.SleepForMicroseconds(1);
- };
-
- TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
-
- // Wait for scheduling_thread to sleep.
- env.BlockUntilThreadsAsleep(1);
- // Enqueue 4 batches.
- for (int i = 0; i < 4; i++) {
- TF_ASSERT_OK(ScheduleTask(900 + i, queue.get()));
- env.AdvanceByMicroseconds(1);
- }
- feedback = -300;
- env.AdvanceByMicroseconds(996);
- env.BlockUntilThreadsAsleep(2);
- // ewma_feedback = 100, scheduling_period = 900.
- EXPECT_EQ(scheduled_items, 900);
- env.AdvanceByMicroseconds(899);
- // No callback scheduled, only scheduling thread sleeping.
- env.BlockUntilThreadsAsleep(1);
- EXPECT_EQ(scheduled_items, 900);
- env.AdvanceByMicroseconds(1);
- env.BlockUntilThreadsAsleep(2);
- // ewma_feedback = 167, scheduling_period = 750.
- EXPECT_EQ(scheduled_items, 901);
- env.AdvanceByMicroseconds(749);
- // No callback scheduled, only scheduling thread sleeping.
- env.BlockUntilThreadsAsleep(1);
- EXPECT_EQ(scheduled_items, 901);
- feedback = 1000 / 3.;
- env.AdvanceByMicroseconds(1);
- env.BlockUntilThreadsAsleep(2);
-    // ewma_feedback = 0, scheduling_period = 750.
- EXPECT_EQ(scheduled_items, 902);
- env.AdvanceByMicroseconds(749);
- // No callback scheduled, only scheduling thread sleeping.
- env.BlockUntilThreadsAsleep(1);
- EXPECT_EQ(scheduled_items, 902);
- env.AdvanceByMicroseconds(1);
- env.BlockUntilThreadsAsleep(2);
- EXPECT_EQ(scheduled_items, 903);
- start_teardown.Notify();
- }
- stop_teardown.Notify();
-}
-
-TEST(AdaptiveSharedBatchSchedulerTest, QueueCapacityInfo) {
- test_util::FakeClockEnv env(Env::Default());
- Notification start_teardown, stop_teardown;
- std::unique_ptr<Thread> teardown_thread =
- CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
- {
- AdaptiveSharedBatchScheduler<FakeTask>::Options options;
- options.initial_scheduling_period_micros = 1000;
- options.env = &env;
- std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
- TF_ASSERT_OK(
- AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
- std::unique_ptr<BatchScheduler<FakeTask>> queue;
- int scheduled_items = 0;
- auto queue_callback = [&scheduled_items,
- &env](std::unique_ptr<Batch<FakeTask>> batch) {
- ASSERT_TRUE(batch->IsClosed());
- EXPECT_GT(batch->num_tasks(), 0);
- scheduled_items = 0;
- for (int i = 0; i < batch->num_tasks(); i++) {
- scheduled_items += batch->task(i).size();
- }
- env.SleepForMicroseconds(1);
- };
- AdaptiveSharedBatchScheduler<FakeTask>::QueueOptions queue_options;
- queue_options.max_batch_size = 10;
- queue_options.max_enqueued_batches = 10;
- TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_callback, &queue));
-
- // Wait for scheduling_thread to sleep.
- env.BlockUntilThreadsAsleep(1);
- // Enqueue 3 tasks.
- EXPECT_EQ(queue->NumEnqueuedTasks(), 0);
- EXPECT_EQ(queue->SchedulingCapacity(), 100);
- TF_ASSERT_OK(ScheduleTask(5, queue.get()));
- EXPECT_EQ(queue->NumEnqueuedTasks(), 1);
- EXPECT_EQ(queue->SchedulingCapacity(), 95);
- env.AdvanceByMicroseconds(1);
- TF_ASSERT_OK(ScheduleTask(6, queue.get()));
- EXPECT_EQ(queue->NumEnqueuedTasks(), 2);
- EXPECT_EQ(queue->SchedulingCapacity(), 84);
- env.AdvanceByMicroseconds(1);
- TF_ASSERT_OK(ScheduleTask(1, queue.get()));
- EXPECT_EQ(queue->NumEnqueuedTasks(), 3);
- EXPECT_EQ(queue->SchedulingCapacity(), 83);
-
- env.AdvanceByMicroseconds(998);
- env.BlockUntilThreadsAsleep(2);
- EXPECT_EQ(scheduled_items, 5);
- env.AdvanceByMicroseconds(1000);
- env.BlockUntilThreadsAsleep(2);
- EXPECT_EQ(scheduled_items, 7);
- start_teardown.Notify();
- }
- stop_teardown.Notify();
-}
-
-TEST(AdaptiveSharedBatchSchedulerTest, InFlightBatchesImplementation) {
- AdaptiveSharedBatchScheduler<FakeTask>::Options options;
- options.use_in_flight_batches_implementation = true;
- options.initial_in_flight_batches_limit = 2;
- options.batches_to_average_over = 1000;
- mutex mu;
- int processed_batches = 0;
- Notification finish_processing;
- auto queue_callback = [&mu, &processed_batches, &finish_processing](
- std::unique_ptr<Batch<FakeTask>> batch) {
- ASSERT_TRUE(batch->IsClosed());
- EXPECT_GT(batch->num_tasks(), 0);
- mu.lock();
- int batch_num = ++processed_batches;
- mu.unlock();
- if (batch_num == 2) {
- // Give third batch a chance to process if it's going to.
- Env::Default()->SleepForMicroseconds(1000);
- finish_processing.Notify();
- }
- if (batch_num == 3) {
- ASSERT_TRUE(finish_processing.HasBeenNotified());
- }
- finish_processing.WaitForNotification();
- };
- std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
- TF_ASSERT_OK(
- AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
- std::unique_ptr<BatchScheduler<FakeTask>> queue;
- TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
-
- // Enqueue 3 batches.
- for (int i = 0; i < 3; i++) {
- TF_ASSERT_OK(ScheduleTask(100, queue.get()));
- }
-}
-
-TEST(AdaptiveSharedBatchSchedulerTest, InFlightBatchesLimitTuning) {
- test_util::FakeClockEnv env(Env::Default());
- Notification start_teardown, stop_teardown;
- std::unique_ptr<Thread> teardown_thread =
- CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
- {
- AdaptiveSharedBatchScheduler<FakeTask>::Options options;
- options.env = &env;
- options.use_in_flight_batches_implementation = true;
- options.initial_in_flight_batches_limit = 2;
- options.batches_to_average_over = 1;
- auto queue_callback = [&env](std::unique_ptr<Batch<FakeTask>> batch) {
- ASSERT_TRUE(batch->IsClosed());
- switch (batch->size()) {
- case 0:
- env.AdvanceByMicroseconds(10);
- break;
- case 1:
- env.AdvanceByMicroseconds(15);
- break;
- case 2:
- env.AdvanceByMicroseconds(10);
- break;
- case 3:
- env.AdvanceByMicroseconds(11);
- break;
- }
- };
- std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
- TF_ASSERT_OK(
- AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
- std::unique_ptr<BatchScheduler<FakeTask>> queue;
- TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
-
- TF_ASSERT_OK(ScheduleTask(0, queue.get()));
- double in_flight_batches_limit = 2;
- while (scheduler->in_flight_batches_limit() == in_flight_batches_limit) {
- }
- // Initial direction will be negative.
- EXPECT_LT(scheduler->in_flight_batches_limit(), in_flight_batches_limit);
- in_flight_batches_limit = scheduler->in_flight_batches_limit();
- TF_ASSERT_OK(ScheduleTask(1, queue.get()));
- while (scheduler->in_flight_batches_limit() == in_flight_batches_limit) {
- }
- // Latency increased -> change direction.
- EXPECT_GT(scheduler->in_flight_batches_limit(), in_flight_batches_limit);
- in_flight_batches_limit = scheduler->in_flight_batches_limit();
- TF_ASSERT_OK(ScheduleTask(2, queue.get()));
- while (scheduler->in_flight_batches_limit() == in_flight_batches_limit) {
- }
- // Latency decreased -> keep going in same direction.
- EXPECT_GT(scheduler->in_flight_batches_limit(), in_flight_batches_limit);
- in_flight_batches_limit = scheduler->in_flight_batches_limit();
- TF_ASSERT_OK(ScheduleTask(3, queue.get()));
- while (scheduler->in_flight_batches_limit() == in_flight_batches_limit) {
- }
- // Latency increased -> change direction.
- EXPECT_LT(scheduler->in_flight_batches_limit(), in_flight_batches_limit);
- start_teardown.Notify();
- }
- stop_teardown.Notify();
-}
-} // namespace anonymous
-} // namespace serving
-} // namespace tensorflow
diff --git a/tensorflow/contrib/batching/basic_batch_scheduler.h b/tensorflow/contrib/batching/basic_batch_scheduler.h
index 91065db249..63ba8fcf45 100644
--- a/tensorflow/contrib/batching/basic_batch_scheduler.h
+++ b/tensorflow/contrib/batching/basic_batch_scheduler.h
@@ -16,253 +16,6 @@ limitations under the License.
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_
-#include <stddef.h>
-#include <cstddef>
-#include <functional>
-#include <memory>
-#include <string>
-
-#include "tensorflow/contrib/batching/shared_batch_scheduler.h"
-
-namespace tensorflow {
-namespace serving {
-
-// A BatchScheduler implementation geared toward handling a single request type
-// running on a specific set of hardware resources. A typical scenario is one in
-// which all requests invoke the same machine-learned model on one GPU.
-//
-// If there are, say, two GPUs and two models each bound to one of the GPUs, one
-// could use two BasicBatchScheduler instances to schedule the two model/GPU
-// combinations independently. If multiple models must share a given GPU or
-// other hardware resource, consider using SharedBatchScheduler instead.
-//
-//
-// PARAMETERS AND BEHAVIOR:
-//
-// BasicBatchScheduler runs a fixed pool of threads, which it uses to process
-// batches of tasks. It enforces a maximum batch size, and enqueues a bounded
-// number of tasks. If the queue is nearly empty, such that a full batch cannot
-// be formed, when a thread becomes free, it anyway schedules a batch
-// immediately if a task has been in the queue for longer than a given timeout
-// parameter. If the timeout parameter is set to 0, then the batch threads will
-// always be kept busy (unless there are zero tasks waiting to be processed).
-//
-// For online serving, it is recommended to set the maximum number of enqueued
-// batches worth of tasks equal to the number of batch threads, which allows
-// enqueuing of enough tasks s.t. if every thread becomes available it can be
-// kept busy, but no more. For bulk processing jobs and throughput-oriented
-// benchmarks, you may want to set it much higher.
-//
-// When Schedule() is called, if the queue is full the call will fail with an
-// UNAVAILABLE error (after which the client may retry again later). If the call
-// succeeds, the maximum time the task will spend in the queue before being
-// placed in a batch and assigned to a thread for processing, is the greater of:
-// - the maximum time to process ceil(max_enqueued_batches/num_batch_threads)
-// (1 in the recommended configuration) batches of previously-submitted tasks
-// - the configured timeout parameter (which can be 0, as mentioned above)
-//
-// Unlike StreamingBatchScheduler, when BasicBatchScheduler assigns a batch to a
-// thread, it closes the batch. The process-batch callback may assume that every
-// batch it receives is closed at the outset.
-//
-//
-// RECOMMENDED USE-CASES:
-//
-// BasicBatchScheduler is suitable for use-cases that feature a single kind of
-// request (e.g. a server performing inference with a single machine-learned
-// model, possibly evolving over time), with loose versioning semantics.
-// Concretely, the following conditions should hold:
-//
-// A. All requests batched onto a given resource (e.g. a hardware accelerator,
-//    or a pool of accelerators) are of the same type. For example, they all
-// invoke the same machine-learned model.
-//
-// These variations are permitted:
-// - The model may reside in a single servable, or it may be spread across
-// multiple servables that are used in unison (e.g. a vocabulary lookup
-// table servable and a tensorflow session servable).
-// - The model's servable(s) may be static, or they may evolve over time
-// (successive servable versions).
-// - Zero or more of the servables are used in the request thread; the rest
-// are used in the batch thread. In our running example, the vocabulary
-// lookups and tensorflow runs may both be performed in the batch thread,
-// or alternatively the vocabulary lookup may occur in the request thread
-// with only the tensorflow run performed in the batch thread.
-//
-// In contrast, BasicBatchScheduler is not a good fit if the server
-//    hosts multiple distinct models running on a pool of accelerators, with each
-// request specifying which model it wants to use. BasicBatchScheduler
-// has no facility to time-multiplex the batch threads across multiple
-// models in a principled way. More basically, it cannot ensure that a given
-// batch doesn't contain a mixture of requests for different models.
-//
-// B. Requests do not specify a particular version of the servable(s) that must
-// be used. Instead, each request is content to use the "latest" version.
-//
-// BasicBatchScheduler does not constrain which requests get grouped
-// together into a batch, so using this scheduler there is no way to achieve
-// cohesion of versioned requests to version-specific batches.
-//
-// C. No servable version coordination needs to be performed between the
-// request threads and the batch threads. Often, servables are only used in
-// the batch threads, in which case this condition trivially holds. If
-// servables are used in both threads, then the use-case must tolerate
-// version skew across the servables used in the two kinds of threads.
-//
-//
-// EXAMPLE USE-CASE FLOW:
-//
-// For such use-cases, request processing via BasicBatchScheduler generally
-// follows this flow (given for illustration; variations are possible):
-// 1. Optionally perform some pre-processing on each request in the request
-// threads.
-// 2. Route the requests to the batch scheduler, as batching::Task objects.
-// (Since all requests are of the same type and are not versioned, the
-// scheduler is free to group them into batches arbitrarily.)
-// 3. Merge the requests into a single batched representation B.
-// 4. Obtain handles to the servable(s) needed to process B. The simplest
-// approach is to obtain the latest version of each servable. Alternatively,
-// if cross-servable consistency is required (e.g. the vocabulary lookup
-// table's version number must match that of the tensorflow session),
-// identify an appropriate version number and obtain the servable handles
-// accordingly.
-// 5. Process B using the obtained servable handles, and split the result into
-// individual per-request units.
-// 6. Perform any post-processing in the batch thread and/or request thread.
-//
-//
-// PERFORMANCE TUNING: See README.md.
-//
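For orientation only, here is a minimal sketch of how the class declared below is typically wired up, assuming a hypothetical EchoTask subclass of BatchTask (a real task type would also carry input/output pointers and a completion signal, as described in batch_scheduler.h). Note that with max_enqueued_batches == num_batch_threads, ceil(max_enqueued_batches / num_batch_threads) == 1, so an accepted task waits at most the greater of one batch's processing time and the configured timeout.

class EchoTask : public BatchTask {
 public:
  explicit EchoTask(size_t size) : size_(size) {}
  size_t size() const override { return size_; }

 private:
  const size_t size_;
};

void ExampleBasicSchedulerSetup() {
  BasicBatchScheduler<EchoTask>::Options options;
  options.max_batch_size = 32;
  options.batch_timeout_micros = 1 * 1000;  // 1 ms latency bound under low load.
  options.num_batch_threads = 4;
  options.max_enqueued_batches = 4;  // == num_batch_threads, per the above.
  std::unique_ptr<BasicBatchScheduler<EchoTask>> scheduler;
  TF_CHECK_OK(BasicBatchScheduler<EchoTask>::Create(
      options,
      [](std::unique_ptr<Batch<EchoTask>> batch) {
        // Process the (closed) batch and report results to each task.
      },
      &scheduler));
  auto task = std::unique_ptr<EchoTask>(new EchoTask(1));
  TF_CHECK_OK(scheduler->Schedule(&task));  // On OK, 'task' becomes nullptr.
}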
-template <typename TaskType>
-class BasicBatchScheduler : public BatchScheduler<TaskType> {
- public:
- // TODO(b/25089730): Tune defaults based on best practices as they develop.
- // (Keep them mirrored to the ones in SharedBatchScheduler::QueueOptions and
- // SharedBatchScheduler::Options.)
- struct Options {
- // The maximum size of each batch.
- //
- // The scheduler may form batches of any size between 1 and this number
- // (inclusive). If there is a need to quantize the batch sizes, i.e. only
- // submit batches whose size is in a small set of allowed sizes, that can be
- // done by adding padding in the process-batch callback.
- int max_batch_size = 1000;
-
- // If a task has been enqueued for this amount of time (in microseconds),
- // and a thread is available, the scheduler will immediately form a batch
- // from enqueued tasks and assign the batch to the thread for processing,
- // even if the batch's size is below 'max_batch_size'.
- //
- // This parameter offers a way to bound queue latency, so that a task isn't
- // stuck in the queue indefinitely waiting for enough tasks to arrive to
- // make a full batch. (The latency bound is given in the class documentation
- // above.)
- //
- // The goal is to smooth out batch sizes under low request rates, and thus
- // avoid latency spikes.
- int64 batch_timeout_micros = 0;
-
- // The name to use for the pool of batch threads.
- string thread_pool_name = {"batch_threads"};
-
- // The number of threads to use to process batches.
- // Must be >= 1, and should be tuned carefully.
- int num_batch_threads = port::NumSchedulableCPUs();
-
- // The maximum allowable number of enqueued (accepted by Schedule() but
- // not yet being processed on a batch thread) tasks in terms of batches.
- // If this limit is reached, Schedule() will return an UNAVAILABLE error.
- // See the class documentation above for guidelines on how to tune this
- // parameter.
- int max_enqueued_batches = 10;
-
- // The following options are typically only overridden by test code.
-
- // The environment to use.
- Env* env = Env::Default();
- };
- static Status Create(const Options& options,
- std::function<void(std::unique_ptr<Batch<TaskType>>)>
- process_batch_callback,
- std::unique_ptr<BasicBatchScheduler>* scheduler);
-
- ~BasicBatchScheduler() override = default;
-
- Status Schedule(std::unique_ptr<TaskType>* task) override;
- size_t NumEnqueuedTasks() const override;
- size_t SchedulingCapacity() const override;
-
- size_t max_task_size() const override {
- return shared_scheduler_queue_->max_task_size();
- }
-
- private:
- explicit BasicBatchScheduler(
- std::unique_ptr<BatchScheduler<TaskType>> shared_scheduler_queue);
-
- // This class is merely a thin wrapper around a SharedBatchScheduler with a
- // single queue.
- std::unique_ptr<BatchScheduler<TaskType>> shared_scheduler_queue_;
-
- TF_DISALLOW_COPY_AND_ASSIGN(BasicBatchScheduler);
-};
-
-//////////
-// Implementation details follow. API users need not read.
-
-template <typename TaskType>
-Status BasicBatchScheduler<TaskType>::Create(
- const Options& options,
- std::function<void(std::unique_ptr<Batch<TaskType>>)>
- process_batch_callback,
- std::unique_ptr<BasicBatchScheduler>* scheduler) {
- typename SharedBatchScheduler<TaskType>::Options shared_scheduler_options;
- shared_scheduler_options.thread_pool_name = options.thread_pool_name;
- shared_scheduler_options.num_batch_threads = options.num_batch_threads;
- shared_scheduler_options.env = options.env;
- std::shared_ptr<SharedBatchScheduler<TaskType>> shared_scheduler;
- TF_RETURN_IF_ERROR(SharedBatchScheduler<TaskType>::Create(
- shared_scheduler_options, &shared_scheduler));
-
- typename SharedBatchScheduler<TaskType>::QueueOptions
- shared_scheduler_queue_options;
- shared_scheduler_queue_options.max_batch_size = options.max_batch_size;
- shared_scheduler_queue_options.batch_timeout_micros =
- options.batch_timeout_micros;
- shared_scheduler_queue_options.max_enqueued_batches =
- options.max_enqueued_batches;
- std::unique_ptr<BatchScheduler<TaskType>> shared_scheduler_queue;
- TF_RETURN_IF_ERROR(shared_scheduler->AddQueue(shared_scheduler_queue_options,
- process_batch_callback,
- &shared_scheduler_queue));
-
- scheduler->reset(
- new BasicBatchScheduler<TaskType>(std::move(shared_scheduler_queue)));
- return Status::OK();
-}
-
-template <typename TaskType>
-Status BasicBatchScheduler<TaskType>::Schedule(
- std::unique_ptr<TaskType>* task) {
- return shared_scheduler_queue_->Schedule(task);
-}
-
-template <typename TaskType>
-size_t BasicBatchScheduler<TaskType>::NumEnqueuedTasks() const {
- return shared_scheduler_queue_->NumEnqueuedTasks();
-}
-
-template <typename TaskType>
-size_t BasicBatchScheduler<TaskType>::SchedulingCapacity() const {
- return shared_scheduler_queue_->SchedulingCapacity();
-}
-
-template <typename TaskType>
-BasicBatchScheduler<TaskType>::BasicBatchScheduler(
- std::unique_ptr<BatchScheduler<TaskType>> shared_scheduler_queue)
- : shared_scheduler_queue_(std::move(shared_scheduler_queue)) {}
-
-} // namespace serving
-} // namespace tensorflow
+#include "tensorflow/core/kernels/batching_util/basic_batch_scheduler.h"
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_
diff --git a/tensorflow/contrib/batching/basic_batch_scheduler_benchmark.cc b/tensorflow/contrib/batching/basic_batch_scheduler_benchmark.cc
deleted file mode 100644
index ab6c810433..0000000000
--- a/tensorflow/contrib/batching/basic_batch_scheduler_benchmark.cc
+++ /dev/null
@@ -1,435 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// Benchmarks for performance (throughput and latency) of BasicBatchScheduler
-// under various rates of task injection.
-
-#include "tensorflow/contrib/batching/basic_batch_scheduler.h"
-#include "tensorflow/core/lib/histogram/histogram.h"
-#include "tensorflow/core/platform/init_main.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/platform/test_benchmark.h"
-
-namespace tensorflow {
-namespace serving {
-namespace {
-
-using ::tensorflow::histogram::Histogram;
-
-// An abstract class for injecting load into a system at a specific rate.
-class LoadInjector {
- public:
- virtual ~LoadInjector() = default;
-
-  // Run 'injector' 'num_injections' times, with an average inter-injection
-  // spacing of 'average_injection_interval_micros' (in microseconds).
- virtual void InjectLoad(std::function<void()> injector, int num_injections,
- int64 average_injection_interval_micros) const = 0;
-};
-
-// A load injector that uses uniform inter-injection spacing, i.e. each pair of
-// injections is separated in time by 'average_injection_interval_micros' (as
-// best as possible).
-class UniformLoadInjector : public LoadInjector {
- public:
- UniformLoadInjector() = default;
- ~UniformLoadInjector() override = default;
-
- void InjectLoad(std::function<void()> injector, int num_injections,
- int64 average_injection_interval_micros) const override;
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(UniformLoadInjector);
-};
-
-void UniformLoadInjector::InjectLoad(
- std::function<void()> injector, const int num_injections,
- const int64 average_injection_interval_micros) const {
- int num_injections_performed = 0;
- const int64 start_time_micros = Env::Default()->NowMicros();
- while (num_injections_performed < num_injections) {
- // Inject.
- injector();
- ++num_injections_performed;
-
- // Wait until it's time for the next injection.
- const int64 next_injection_time_micros =
- start_time_micros +
- (num_injections_performed * average_injection_interval_micros);
- int64 now_micros = Env::Default()->NowMicros();
- while (now_micros < next_injection_time_micros) {
- const int64 kSleepThresholdMicros = 1000;
- if (next_injection_time_micros - now_micros >= kSleepThresholdMicros) {
- Env::Default()->SleepForMicroseconds(1 /* minimum time */);
- }
- now_micros = Env::Default()->NowMicros();
- }
- }
-}
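As an illustrative sketch of the injector contract (the function and counter below are purely for exposition): injecting 10,000 no-op events with a 1000-microsecond spacing, i.e. roughly 1,000 injections per second, could look like this. InjectLoad() runs the callback synchronously on the calling thread, so no extra synchronization is needed.

void ExampleUniformInjection() {
  UniformLoadInjector injector;
  int num_events = 0;
  injector.InjectLoad([&num_events] { ++num_events; },
                      10000 /* num_injections */,
                      1000 /* average_injection_interval_micros */);
  CHECK_EQ(10000, num_events);
}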
-
-class BenchmarkBatchTask : public BatchTask {
- public:
- BenchmarkBatchTask();
-
- BenchmarkBatchTask(const BenchmarkBatchTask&) = delete;
- BenchmarkBatchTask& operator=(const BenchmarkBatchTask&) = delete;
-
- ~BenchmarkBatchTask() override = default;
-
- size_t size() const override { return 1; }
-
- uint64 start_time_micros() const { return start_time_micros_; }
-
- private:
- // The time at which the task was created, in microseconds.
- const uint64 start_time_micros_;
-};
-
-BenchmarkBatchTask::BenchmarkBatchTask()
- : start_time_micros_(Env::Default()->NowMicros()) {}
-
-// The state and logic associated with a throughput benchmark, which injects a
-// large number of tasks into a batch scheduler and measures the total time to
-// process all the tasks.
-class ThroughputBenchmark {
- public:
- explicit ThroughputBenchmark(
- const BasicBatchScheduler<BenchmarkBatchTask>::Options&
- scheduler_options);
-
- ThroughputBenchmark(const ThroughputBenchmark&) = delete;
- ThroughputBenchmark& operator=(const ThroughputBenchmark&) = delete;
-
- // Perform the benchmark run, based on the parameters supplied to the ctor.
- void RunBenchmark(int iters);
-
- private:
- // Resets all mutable state, including the scheduler.
- void ResetState();
-
- // Processes a batch of tasks. (Invoked by 'scheduler_' on one of its batch
- // threads.)
- void ProcessBatch(std::unique_ptr<Batch<BenchmarkBatchTask>> batch);
-
- // Parameters for the BasicBatchScheduler being benchmarked.
- const BasicBatchScheduler<BenchmarkBatchTask>::Options scheduler_options_;
-
- // The BasicBatchScheduler being benchmarked.
- std::unique_ptr<BasicBatchScheduler<BenchmarkBatchTask>> scheduler_;
-};
-
-ThroughputBenchmark::ThroughputBenchmark(
- const BasicBatchScheduler<BenchmarkBatchTask>::Options& scheduler_options)
- : scheduler_options_(scheduler_options) {}
-
-void ThroughputBenchmark::RunBenchmark(int iters) {
- CHECK_GE(iters, 1);
-
- testing::StopTiming();
- ResetState();
-
- // Have each iteration issue a reasonably large number of tasks, to ensure our
- // measurements reflect steady-state behavior.
- const int kNumTasksPerIteration = 100 * 1000;
-
- testing::ItemsProcessed(iters * kNumTasksPerIteration);
- testing::UseRealTime();
- testing::StartTiming();
-
-  // Schedule 'iters * kNumTasksPerIteration' tasks.
- for (int i = 0; i < iters; ++i) {
- for (int j = 0; j < kNumTasksPerIteration; ++j) {
- auto task = std::unique_ptr<BenchmarkBatchTask>(new BenchmarkBatchTask);
- TF_CHECK_OK(scheduler_->Schedule(&task));
- }
- }
-
- // Wait for the scheduler to process all tasks.
- scheduler_.reset();
- testing::StopTiming();
-}
-
-void ThroughputBenchmark::ResetState() {
- auto process_batch_callback =
- [this](std::unique_ptr<Batch<BenchmarkBatchTask>> batch) {
- ProcessBatch(std::move(batch));
- };
- TF_CHECK_OK(BasicBatchScheduler<BenchmarkBatchTask>::Create(
- scheduler_options_, process_batch_callback, &scheduler_));
-}
-
-void ThroughputBenchmark::ProcessBatch(
- std::unique_ptr<Batch<BenchmarkBatchTask>> batch) {
- // No-op.
-}
-
-// The state and logic associated with a latency benchmark, which injects tasks
-// into a batch scheduler at a controlled rate and measures the distribution of
-// task completion latencies.
-//
-// Reports the measurements to std::cout (not LOG(INFO)), like the throughput
-// measurements.
-class LatencyBenchmark {
- public:
- LatencyBenchmark(
- const BasicBatchScheduler<BenchmarkBatchTask>::Options& scheduler_options,
- int64 task_injection_interval_micros, int batch_cpu_cost);
-
- LatencyBenchmark(const LatencyBenchmark&) = delete;
- LatencyBenchmark& operator=(const LatencyBenchmark&) = delete;
-
- // Perform the benchmark run, based on the parameters supplied to the ctor.
- void RunBenchmark();
-
- private:
- // Resets all mutable state, including the scheduler and latency measurements.
- void ResetState() LOCKS_EXCLUDED(mu_);
-
- // Processes a batch of tasks. (Invoked by 'scheduler_' on one of its batch
- // threads.)
- void ProcessBatch(std::unique_ptr<Batch<BenchmarkBatchTask>> batch);
-
- // Performs one batch's dummy CPU work.
- void PerformBatchCpuWork() const;
-
- // Parameters for the BasicBatchScheduler being benchmarked.
- const BasicBatchScheduler<BenchmarkBatchTask>::Options scheduler_options_;
-
- // The time interval between successively injected tasks, in microseconds.
- // A large interval corresponds to a slow rate of task injection, and vice-
- // versa.
- const int64 task_injection_interval_micros_;
-
- // The amount of work to do while processing one batch of tasks. (The cost is
- // independent of the number of tasks in the batch.)
- const int batch_cpu_cost_;
-
- // The BasicBatchScheduler being benchmarked.
- std::unique_ptr<BasicBatchScheduler<BenchmarkBatchTask>> scheduler_;
-
- mutable mutex mu_;
-
- // A histogram of the task latencies, i.e. queue time plus processing time, in
- // milliseconds.
- Histogram task_latency_millis_histogram_ GUARDED_BY(mu_);
-
- // A histogram of the batch sizes.
- Histogram batch_size_histogram_ GUARDED_BY(mu_);
-};
-
-LatencyBenchmark::LatencyBenchmark(
- const BasicBatchScheduler<BenchmarkBatchTask>::Options& scheduler_options,
- int64 task_injection_interval_micros, int batch_cpu_cost)
- : scheduler_options_(scheduler_options),
- task_injection_interval_micros_(task_injection_interval_micros),
- batch_cpu_cost_(batch_cpu_cost) {}
-
-void LatencyBenchmark::RunBenchmark() {
- ResetState();
-
- // Arrange to inject tasks at the specified rate, for a fixed total time
- // duration.
- const int kTimeDurationMicros = 100 * 1000 * 1000 /* 100 seconds */;
- const int kNumTasks = kTimeDurationMicros / task_injection_interval_micros_;
- CHECK_GE(kNumTasks, 100000)
- << "Not enough tasks to report meaningful 99.9% latency";
-
- const int64 start_time_micros = Env::Default()->NowMicros();
-
- // Inject the tasks.
- UniformLoadInjector injector;
- injector.InjectLoad(
- [this] {
- auto task = std::unique_ptr<BenchmarkBatchTask>(new BenchmarkBatchTask);
- TF_CHECK_OK(scheduler_->Schedule(&task));
- },
- kNumTasks, task_injection_interval_micros_);
-
- // Be sure we were able to more-or-less match our target injection rate.
- const int64 target_injection_time_micros =
- kNumTasks * task_injection_interval_micros_;
- const int64 actual_injection_time_micros =
- Env::Default()->NowMicros() - start_time_micros;
- if (actual_injection_time_micros > 1.1 * target_injection_time_micros) {
- LOG(FATAL) << "Unable to inject tasks at the requested rate";
- }
-
- // Wait for the scheduler to process all injected tasks.
- scheduler_.reset();
-
- // Be sure the scheduler was able to process the tasks at close to the
- // injection rate. If not, our latency measurements will be dominated by queue
-  // waiting time.
- const int64 actual_processing_time_micros =
- Env::Default()->NowMicros() - start_time_micros;
- if (actual_processing_time_micros > 1.01 * actual_injection_time_micros) {
- LOG(FATAL) << "Unable to keep up with task injection rate";
- }
-
- // Report benchmark measurements.
- {
- mutex_lock l(mu_);
- std::cout << "\t"
- << "99.9% latency: "
- << task_latency_millis_histogram_.Percentile(99.9) << "ms"
- << "\t"
- << "99% batch size: " << batch_size_histogram_.Percentile(99)
- << std::endl;
- }
-}
-
-void LatencyBenchmark::ResetState() {
- auto process_batch_callback =
- [this](std::unique_ptr<Batch<BenchmarkBatchTask>> batch) {
- ProcessBatch(std::move(batch));
- };
- TF_CHECK_OK(BasicBatchScheduler<BenchmarkBatchTask>::Create(
- scheduler_options_, process_batch_callback, &scheduler_));
-
- {
- mutex_lock l(mu_);
- task_latency_millis_histogram_.Clear();
- batch_size_histogram_.Clear();
- }
-}
-
-void LatencyBenchmark::ProcessBatch(
- std::unique_ptr<Batch<BenchmarkBatchTask>> batch) {
- PerformBatchCpuWork();
- const uint64 batch_completion_time = Env::Default()->NowMicros();
-
- {
- mutex_lock l(mu_);
- batch_size_histogram_.Add(batch->num_tasks());
- }
-
- for (int i = 0; i < batch->num_tasks(); ++i) {
- const BenchmarkBatchTask& task = batch->task(i);
-
- const uint64 task_latency_micros =
- batch_completion_time - task.start_time_micros();
-
- {
- mutex_lock l(mu_);
- task_latency_millis_histogram_.Add(task_latency_micros / 1000.0);
- }
- }
-}
-
-void LatencyBenchmark::PerformBatchCpuWork() const {
- int dummy = 1;
- for (int i = 0; i < batch_cpu_cost_; ++i) {
- dummy += dummy * 2;
- }
- CHECK_NE(dummy, 0);
-}
-
-static void RunThroughputBenchmark(int iters, int64 batch_timeout_micros,
- int num_batch_threads) {
- BasicBatchScheduler<BenchmarkBatchTask>::Options scheduler_options;
- const int kMaxBatchSize = 100;
- scheduler_options.max_batch_size = kMaxBatchSize;
- scheduler_options.batch_timeout_micros = batch_timeout_micros;
- scheduler_options.num_batch_threads = num_batch_threads;
- scheduler_options.max_enqueued_batches = INT_MAX; // Unbounded queue.
- ThroughputBenchmark benchmark(scheduler_options);
- benchmark.RunBenchmark(iters);
-}
-
-static void ThroughputBM_ZeroTimeout(int iters, int num_batch_threads) {
- RunThroughputBenchmark(iters, 0 /* 0 ms timeout */, num_batch_threads);
-}
-BENCHMARK(ThroughputBM_ZeroTimeout)
- ->Arg(1)
- ->Arg(2)
- ->Arg(4)
- ->Arg(8)
- ->Arg(16)
- ->Arg(32)
- ->Arg(64);
-
-static void ThroughputBM_SmallTimeout(int iters, int num_batch_threads) {
- RunThroughputBenchmark(iters, 1 * 1000 /* 1 ms timeout */, num_batch_threads);
-}
-BENCHMARK(ThroughputBM_SmallTimeout)
- ->Arg(1)
- ->Arg(2)
- ->Arg(4)
- ->Arg(8)
- ->Arg(16)
- ->Arg(32)
- ->Arg(64);
-
-static void ThroughputBM_LargeTimeout(int iters, int num_batch_threads) {
- RunThroughputBenchmark(iters, 50 * 1000 /* 50 ms timeout */,
- num_batch_threads);
-}
-BENCHMARK(ThroughputBM_LargeTimeout)
- ->Arg(1)
- ->Arg(2)
- ->Arg(4)
- ->Arg(8)
- ->Arg(16)
- ->Arg(32)
- ->Arg(64);
-
-static void RunLatencyBenchmark(int64 task_injection_interval_micros,
- int64 batch_timeout_micros) {
- BasicBatchScheduler<BenchmarkBatchTask>::Options scheduler_options;
- const int kMaxBatchSize = 100;
- scheduler_options.max_batch_size = kMaxBatchSize;
- scheduler_options.batch_timeout_micros = batch_timeout_micros;
- const int kNumBatchThreads = 2;
- scheduler_options.num_batch_threads = kNumBatchThreads;
- scheduler_options.max_enqueued_batches = INT_MAX; // Unbounded queue.
- const int kBatchCpuCost = 10 * 1000 * 1000;
- LatencyBenchmark benchmark(scheduler_options, task_injection_interval_micros,
- kBatchCpuCost);
- benchmark.RunBenchmark();
-}
-
-static void RunLatencyBenchmarks() {
- for (const int64 batch_timeout_micros : {0, 1 * 1000, 2 * 1000, 5 * 1000}) {
- for (const int64 task_injection_interval_micros : {1000, 50, 20}) {
- std::cout << "Latency benchmark w/ batch timeout "
- << batch_timeout_micros / 1000.0 << "ms"
- << "; "
- << "task injection rate "
- << 1000000.0 / task_injection_interval_micros << "/sec"
- << "\t...";
- RunLatencyBenchmark(task_injection_interval_micros, batch_timeout_micros);
- }
- std::cout << std::endl;
- }
-}
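For reference, the sweep above covers injection intervals of 1000, 50 and 20 microseconds, i.e. roughly 1,000, 20,000 and 50,000 tasks per second; over the fixed 100-second window in RunLatencyBenchmark() this yields kNumTasks of 100,000, 2,000,000 and 5,000,000 respectively, so every configuration satisfies the CHECK_GE(kNumTasks, 100000) guard.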
-
-} // namespace
-} // namespace serving
-} // namespace tensorflow
-
-int main(int argc, char** argv) {
- tensorflow::port::InitMain(argv[0], &argc, &argv);
-  std::cout << std::setprecision(5);
-
- // Run latency benchmarks (outside of tensorflow benchmark framework).
- tensorflow::serving::RunLatencyBenchmarks();
-
- // Run throughput benchmarks (via tensorflow benchmark framework).
- tensorflow::testing::RunBenchmarks();
-
- return 0;
-}
diff --git a/tensorflow/contrib/batching/basic_batch_scheduler_test.cc b/tensorflow/contrib/batching/basic_batch_scheduler_test.cc
deleted file mode 100644
index 187823151c..0000000000
--- a/tensorflow/contrib/batching/basic_batch_scheduler_test.cc
+++ /dev/null
@@ -1,91 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/contrib/batching/basic_batch_scheduler.h"
-
-#include <utility>
-
-#include "tensorflow/contrib/batching/batch_scheduler.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace serving {
-namespace {
-
-class FakeTask : public BatchTask {
- public:
- explicit FakeTask(size_t size) : size_(size) {}
-
- ~FakeTask() override = default;
-
- size_t size() const override { return size_; }
-
- private:
- const size_t size_;
-
- TF_DISALLOW_COPY_AND_ASSIGN(FakeTask);
-};
-
-// Creates a FakeTask of size 'task_size', and calls 'scheduler->Schedule()'
-// on that task. Returns the resulting status.
-Status ScheduleTask(size_t task_size, BatchScheduler<FakeTask>* scheduler) {
- std::unique_ptr<FakeTask> task(new FakeTask(task_size));
- Status status = scheduler->Schedule(&task);
- // Schedule() should have consumed 'task' iff it returned Status::OK.
- CHECK_EQ(status.ok(), task == nullptr);
- return status;
-}
-
-// Since BasicBatchScheduler is implemented as a thin wrapper around
-// SharedBatchScheduler, we only do some basic testing. More comprehensive
-// testing is done in shared_batch_scheduler_test.cc.
-
-TEST(BasicBatchSchedulerTest, Basic) {
- bool callback_called = false;
- auto callback = [&callback_called](std::unique_ptr<Batch<FakeTask>> batch) {
- callback_called = true;
- ASSERT_TRUE(batch->IsClosed());
- ASSERT_EQ(2, batch->num_tasks());
- EXPECT_EQ(3, batch->task(0).size());
- EXPECT_EQ(5, batch->task(1).size());
- };
- {
- BasicBatchScheduler<FakeTask>::Options options;
- options.max_batch_size = 10;
- options.batch_timeout_micros = 100 * 1000; // 100 milliseconds
- options.num_batch_threads = 1;
- options.max_enqueued_batches = 3;
- std::unique_ptr<BasicBatchScheduler<FakeTask>> scheduler;
- TF_ASSERT_OK(
- BasicBatchScheduler<FakeTask>::Create(options, callback, &scheduler));
- EXPECT_EQ(10, scheduler->max_task_size());
- EXPECT_EQ(0, scheduler->NumEnqueuedTasks());
- EXPECT_EQ(3 * 10, scheduler->SchedulingCapacity());
- TF_ASSERT_OK(ScheduleTask(3, scheduler.get()));
- EXPECT_EQ(1, scheduler->NumEnqueuedTasks());
- EXPECT_EQ((3 * 10) - 3, scheduler->SchedulingCapacity());
- TF_ASSERT_OK(ScheduleTask(5, scheduler.get()));
- EXPECT_EQ(2, scheduler->NumEnqueuedTasks());
- EXPECT_EQ((3 * 10) - (3 + 5), scheduler->SchedulingCapacity());
- }
- EXPECT_TRUE(callback_called);
-}
-
-} // namespace
-} // namespace serving
-} // namespace tensorflow
diff --git a/tensorflow/contrib/batching/batch_scheduler.h b/tensorflow/contrib/batching/batch_scheduler.h
index aa8891ab4e..3afce2761f 100644
--- a/tensorflow/contrib/batching/batch_scheduler.h
+++ b/tensorflow/contrib/batching/batch_scheduler.h
@@ -13,269 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// Abstractions for processing small tasks in a batched fashion, to reduce
-// processing times and costs that can be amortized across multiple tasks.
-//
-// The core class is BatchScheduler, which groups tasks into batches.
-//
-// BatchScheduler encapsulates logic for aggregating multiple tasks into a
-// batch, and kicking off processing of a batch on a thread pool it manages.
-//
-// This file defines an abstract BatchScheduler class.
-
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_
-#include <stddef.h>
-#include <algorithm>
-#include <functional>
-#include <memory>
-#include <utility>
-#include <vector>
-
-#include "tensorflow/core/lib/core/notification.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/thread_annotations.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-namespace serving {
-
-// The abstract superclass for a unit of work to be done as part of a batch.
-//
-// An implementing subclass typically contains (or points to):
-// (a) input data;
-// (b) a thread-safe completion signal (e.g. a Notification);
-// (c) a place to store the outcome (success, or some error), upon completion;
-// (d) a place to store the output data, upon success.
-//
-// Items (b), (c) and (d) are typically non-owned pointers to data homed
-// elsewhere, because a task's ownership gets transferred to a BatchScheduler
-// (see below) and it may be deleted as soon as it is done executing.
-class BatchTask {
- public:
- virtual ~BatchTask() = default;
-
- // Returns the size of the task, in terms of how much it contributes to the
- // size of a batch. (A batch's size is the sum of its task sizes.)
- virtual size_t size() const = 0;
-};
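To make items (a)-(d) above concrete, a task type for a hypothetical string-in/string-out server might look like the following sketch (the field and type names are illustrative only, not part of this API):

class EchoRequestTask : public BatchTask {
 public:
  EchoRequestTask(string request, Notification* done, Status* status,
                  string* response)
      : request_(std::move(request)),
        done_(done),
        status_(status),
        response_(response) {}

  // Each request contributes one unit to the size of its batch.
  size_t size() const override { return 1; }

  const string request_;      // (a) input data, owned by the task.
  Notification* const done_;  // (b) completion signal, homed in the caller.
  Status* const status_;      // (c) outcome slot, homed in the caller.
  string* const response_;    // (d) output slot, homed in the caller.
};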
-
-// A thread-safe collection of BatchTasks, to be executed together in some
-// fashion.
-//
-// At a given time, a batch is either "open" or "closed": an open batch can
-// accept new tasks; a closed one cannot. A batch is monotonic: initially it is
-// open and tasks can be added to it; then it is closed and its set of tasks
-// remains fixed for the remainder of its life. A closed batch cannot be re-
-// opened. Tasks can never be removed from a batch.
-//
-// Type parameter TaskType must be a subclass of BatchTask.
-template <typename TaskType>
-class Batch {
- public:
- Batch() = default;
- virtual ~Batch(); // Blocks until the batch is closed.
-
- // Appends 'task' to the batch. After calling AddTask(), the newly-added task
- // can be accessed via task(num_tasks()-1) or mutable_task(num_tasks()-1).
- // Dies if the batch is closed.
- void AddTask(std::unique_ptr<TaskType> task);
-
- // Removes the most recently added task. Returns nullptr if the batch is
- // empty.
- std::unique_ptr<TaskType> RemoveTask();
-
- // Returns the number of tasks in the batch.
- int num_tasks() const;
-
- // Returns true iff the batch contains 0 tasks.
- bool empty() const;
-
- // Returns a reference to the ith task (in terms of insertion order).
- const TaskType& task(int i) const;
-
- // Returns a pointer to the ith task (in terms of insertion order).
- TaskType* mutable_task(int i);
-
- // Returns the sum of the task sizes.
- size_t size() const;
-
- // Returns true iff the batch is currently closed.
- bool IsClosed() const;
-
- // Blocks until the batch is closed.
- void WaitUntilClosed() const;
-
- // Marks the batch as closed. Dies if called more than once.
- void Close();
-
- private:
- mutable mutex mu_;
-
- // The tasks in the batch.
- std::vector<std::unique_ptr<TaskType>> tasks_ GUARDED_BY(mu_);
-
- // The sum of the sizes of the tasks in 'tasks_'.
- size_t size_ GUARDED_BY(mu_) = 0;
-
- // Whether the batch has been closed.
- Notification closed_;
-
- TF_DISALLOW_COPY_AND_ASSIGN(Batch);
-};
-
-// An abstract batch scheduler class. Collects individual tasks into batches,
-// and processes each batch on a pool of "batch threads" that it manages. The
-// actual logic for processing a batch is accomplished via a callback.
-//
-// Type parameter TaskType must be a subclass of BatchTask.
-template <typename TaskType>
-class BatchScheduler {
- public:
- virtual ~BatchScheduler() = default;
-
- // Submits a task to be processed as part of a batch.
- //
- // Ownership of '*task' is transferred to the callee iff the method returns
- // Status::OK. In that case, '*task' is left as nullptr. Otherwise, '*task' is
- // left as-is.
- //
- // If no batch processing capacity is available to process this task at the
- // present time, and any task queue maintained by the implementing subclass is
- // full, this method returns an UNAVAILABLE error code. The client may retry
- // later.
- //
- // Other problems, such as the task size being larger than the maximum batch
- // size, yield other, permanent error types.
- //
- // In all cases, this method returns "quickly" without blocking for any
- // substantial amount of time. If the method returns Status::OK, the task is
- // processed asynchronously, and any errors that occur during the processing
- // of the batch that includes the task can be reported to 'task'.
- virtual Status Schedule(std::unique_ptr<TaskType>* task) = 0;
-
- // Returns the number of tasks that have been scheduled (i.e. accepted by
- // Schedule()), but have yet to be handed to a thread for execution as part of
- // a batch. Note that this returns the number of tasks, not the aggregate task
- // size (so if there is one task of size 3 and one task of size 5, this method
- // returns 2 rather than 8).
- virtual size_t NumEnqueuedTasks() const = 0;
-
- // Returns a guaranteed number of size 1 tasks that can be Schedule()d without
- // getting an UNAVAILABLE error. In a typical implementation, returns the
- // available space on a queue.
- //
- // There are two important caveats:
- // 1. The guarantee does not extend to varying-size tasks due to possible
- // internal fragmentation of batches.
- // 2. The guarantee only holds in a single-thread environment or critical
- // section, i.e. if an intervening thread cannot call Schedule().
- //
- // This method is useful for monitoring, or for guaranteeing a future slot in
- // the schedule (but being mindful about the caveats listed above).
- virtual size_t SchedulingCapacity() const = 0;
-
- // Returns the maximum allowed size of tasks submitted to the scheduler. (This
- // is typically equal to a configured maximum batch size.)
- virtual size_t max_task_size() const = 0;
-};
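The ownership-transfer contract of Schedule() above is the part most easily misused, so here is a small illustrative sketch of the intended calling pattern (ScheduleOneTask is a hypothetical helper, not part of this API):

template <typename TaskType>
void ScheduleOneTask(BatchScheduler<TaskType>* scheduler,
                     std::unique_ptr<TaskType> task) {
  const Status status = scheduler->Schedule(&task);
  if (status.ok()) {
    // Ownership was transferred to the scheduler; 'task' is now nullptr and
    // the outcome will be reported through the task's own completion signal.
    DCHECK(task == nullptr);
  } else {
    // On an UNAVAILABLE error (queue full) 'task' is left untouched and the
    // caller may retry later; other error codes (e.g. an oversized task) are
    // permanent.
    LOG(WARNING) << "Schedule() failed: " << status;
  }
}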
-
-//////////
-// Implementation details follow. API users need not read.
-
-template <typename TaskType>
-Batch<TaskType>::~Batch() {
- WaitUntilClosed();
-}
-
-template <typename TaskType>
-void Batch<TaskType>::AddTask(std::unique_ptr<TaskType> task) {
- DCHECK(!IsClosed());
- {
- mutex_lock l(mu_);
- size_ += task->size();
- tasks_.push_back(std::move(task));
- }
-}
-
-template <typename TaskType>
-std::unique_ptr<TaskType> Batch<TaskType>::RemoveTask() {
- {
- mutex_lock l(mu_);
- if (tasks_.empty()) {
- return nullptr;
- }
- std::unique_ptr<TaskType> task = std::move(tasks_.back());
- size_ -= task->size();
- tasks_.pop_back();
- return task;
- }
-}
-
-template <typename TaskType>
-int Batch<TaskType>::num_tasks() const {
- {
- mutex_lock l(mu_);
- return tasks_.size();
- }
-}
-
-template <typename TaskType>
-bool Batch<TaskType>::empty() const {
- {
- mutex_lock l(mu_);
- return tasks_.empty();
- }
-}
-
-template <typename TaskType>
-const TaskType& Batch<TaskType>::task(int i) const {
- DCHECK_GE(i, 0);
- {
- mutex_lock l(mu_);
- DCHECK_LT(i, tasks_.size());
- return *tasks_[i].get();
- }
-}
-
-template <typename TaskType>
-TaskType* Batch<TaskType>::mutable_task(int i) {
- DCHECK_GE(i, 0);
- {
- mutex_lock l(mu_);
- DCHECK_LT(i, tasks_.size());
- return tasks_[i].get();
- }
-}
-
-template <typename TaskType>
-size_t Batch<TaskType>::size() const {
- {
- mutex_lock l(mu_);
- return size_;
- }
-}
-
-template <typename TaskType>
-bool Batch<TaskType>::IsClosed() const {
- return const_cast<Notification*>(&closed_)->HasBeenNotified();
-}
-
-template <typename TaskType>
-void Batch<TaskType>::WaitUntilClosed() const {
- const_cast<Notification*>(&closed_)->WaitForNotification();
-}
-
-template <typename TaskType>
-void Batch<TaskType>::Close() {
- closed_.Notify();
-}
-
-} // namespace serving
-} // namespace tensorflow
+#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_
diff --git a/tensorflow/contrib/batching/batch_scheduler_test.cc b/tensorflow/contrib/batching/batch_scheduler_test.cc
deleted file mode 100644
index b627fee972..0000000000
--- a/tensorflow/contrib/batching/batch_scheduler_test.cc
+++ /dev/null
@@ -1,118 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/contrib/batching/batch_scheduler.h"
-
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace serving {
-namespace {
-
-class FakeTask : public BatchTask {
- public:
- explicit FakeTask(size_t size) : size_(size) {}
-
- ~FakeTask() override = default;
-
- size_t size() const override { return size_; }
-
- private:
- const size_t size_;
-
- TF_DISALLOW_COPY_AND_ASSIGN(FakeTask);
-};
-
-TEST(BatchTest, Basic) {
- Batch<FakeTask> batch;
-
- EXPECT_EQ(0, batch.num_tasks());
- EXPECT_TRUE(batch.empty());
- EXPECT_EQ(0, batch.size());
- EXPECT_FALSE(batch.IsClosed());
-
- auto task0 = new FakeTask(3);
- batch.AddTask(std::unique_ptr<FakeTask>(task0));
-
- EXPECT_EQ(1, batch.num_tasks());
- EXPECT_FALSE(batch.empty());
- EXPECT_EQ(task0->size(), batch.size());
- EXPECT_EQ(task0->size(), batch.task(0).size());
- EXPECT_FALSE(batch.IsClosed());
-
- auto task1 = new FakeTask(7);
- batch.AddTask(std::unique_ptr<FakeTask>(task1));
-
- EXPECT_EQ(2, batch.num_tasks());
- EXPECT_FALSE(batch.empty());
- EXPECT_EQ(task0->size() + task1->size(), batch.size());
- EXPECT_EQ(task1->size(), batch.task(1).size());
- EXPECT_EQ(task1->size(), batch.mutable_task(1)->size());
- EXPECT_FALSE(batch.IsClosed());
-
- batch.Close();
- EXPECT_TRUE(batch.IsClosed());
-
- EXPECT_EQ(2, batch.num_tasks());
- EXPECT_FALSE(batch.empty());
- EXPECT_EQ(task0->size() + task1->size(), batch.size());
- EXPECT_EQ(task0->size(), batch.task(0).size());
- EXPECT_EQ(task1->size(), batch.task(1).size());
-
- EXPECT_EQ(7, batch.RemoveTask()->size());
- EXPECT_EQ(3, batch.size());
- EXPECT_EQ(3, batch.RemoveTask()->size());
- EXPECT_EQ(0, batch.size());
- EXPECT_TRUE(batch.empty());
-}
-
-TEST(BatchTest, WaitUntilClosed) {
- Batch<FakeTask> batch;
- batch.AddTask(std::unique_ptr<FakeTask>(new FakeTask(3)));
- EXPECT_FALSE(batch.IsClosed());
-
- std::unique_ptr<Thread> close_thread(
- Env::Default()->StartThread(ThreadOptions(), "test", [&batch]() {
- Env::Default()->SleepForMicroseconds(100);
- batch.Close();
- }));
- batch.WaitUntilClosed();
- EXPECT_TRUE(batch.IsClosed());
-}
-
-TEST(BatchTest, DeletionBlocksUntilClosed) {
- Batch<FakeTask>* batch = new Batch<FakeTask>;
- batch->AddTask(std::unique_ptr<FakeTask>(new FakeTask(3)));
- EXPECT_FALSE(batch->IsClosed());
-
- Notification do_delete, deleted;
- std::unique_ptr<Thread> delete_thread(Env::Default()->StartThread(
- ThreadOptions(), "test", [&batch, &do_delete, &deleted]() {
- do_delete.WaitForNotification();
- delete batch;
- deleted.Notify();
- }));
- do_delete.Notify();
- Env::Default()->SleepForMicroseconds(10 * 1000 /* 10 milliseconds */);
- EXPECT_FALSE(deleted.HasBeenNotified());
- batch->Close();
- deleted.WaitForNotification();
-}
-
-} // namespace
-} // namespace serving
-} // namespace tensorflow
diff --git a/tensorflow/contrib/batching/shared_batch_scheduler.h b/tensorflow/contrib/batching/shared_batch_scheduler.h
index 86c45bdc2e..7eb1e20c42 100644
--- a/tensorflow/contrib/batching/shared_batch_scheduler.h
+++ b/tensorflow/contrib/batching/shared_batch_scheduler.h
@@ -16,690 +16,6 @@ limitations under the License.
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_
-#include <stddef.h>
-#include <deque>
-#include <functional>
-#include <list>
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "tensorflow/contrib/batching/batch_scheduler.h"
-#include "tensorflow/contrib/batching/util/periodic_function.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/platform/cpu_info.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/thread_annotations.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-namespace serving {
-namespace internal {
-template <typename TaskType>
-class Queue;
-} // namespace internal
-} // namespace serving
-} // namespace tensorflow
-
-namespace tensorflow {
-namespace serving {
-
-// A batch scheduler for server instances that service multiple request types
-// (e.g. multiple machine-learned models, or multiple versions of a model served
-// concurrently), or even multiple distinct tasks for a given request. The
-// scheduler multiplexes batches of different kinds of tasks onto a fixed-size
-// thread pool (each batch contains tasks of a single type), in a carefully
-// controlled manner. A common configuration is to set the number of threads
-// equal to the number of hardware accelerator units, in which case the
-// scheduler takes care of multiplexing the task types onto the shared hardware,
-// in a manner that is both fair and efficient.
-//
-// Semantically, SharedBatchScheduler behaves like having N instances of
-// BasicBatchScheduler (see basic_batch_scheduler.h), one per task type. The
-// difference is that under the covers there is a single shared thread pool,
-// instead of N independent ones, with their sharing deliberately coordinated.
-//
-// SharedBatchScheduler does not implement the BatchScheduler API; rather, it
-// presents an abstraction of "queues", where each queue corresponds to one type
-// of task. Tasks submitted to a given queue are placed in their own batches,
-// and cannot be mixed with other tasks. Queues can be added and deleted
-// dynamically, to accommodate e.g. versions of a model being brought up and
-// down over the lifetime of a server.
-//
-// The batch thread pool round-robins through the queues, running one batch
-// from a queue and then moving to the next queue. Each queue behaves like a
-// BasicBatchScheduler instance, in the sense that it has maximum batch size and
-// timeout parameters, which govern when a batch is eligible to be processed.
-//
-// Each queue is independently configured with a maximum size (in terms of the
-// maximum number of batches worth of enqueued tasks). For online serving, it is
-// recommended that the queue sizes be configured such that the sum of the sizes
-// of the active queues roughly equal the number of batch threads. (The idea is
-// that if all threads become available at roughly the same time, there will be
-// enough enqueued work for them to take on, but no more.)
-//
-// If queue sizes are configured in the manner suggested above, the maximum time
-// a task can spend in a queue before being placed in a batch and assigned to a
-// thread for processing, is the greater of:
-// - the maximum time to process one batch of tasks from any active queue
-// - the configured timeout parameter for the task's queue (which can be 0)
-//
-// For bulk processing jobs and throughput-oriented benchmarks, you may want to
-// set the maximum queue size to a large value.
-//
-// TODO(b/26539183): Support queue servicing policies other than round-robin.
-// E.g. let each queue specify a "share" (an int >= 1), so e.g. with queues A
-// and B having shares 1 and 2 respectively, the servicing pattern is ABBABB...
-//
-//
-// PERFORMANCE TUNING: See README.md.
-//
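For orientation only, a sketch of the two-model configuration described above (the task type and the per-model batch-processing lambdas are hypothetical):

struct MyTask : public BatchTask {
  size_t size() const override { return 1; }
};

void ExampleTwoModelServer() {
  SharedBatchScheduler<MyTask>::Options options;
  options.num_batch_threads = 2;  // e.g. one thread per accelerator.
  std::shared_ptr<SharedBatchScheduler<MyTask>> scheduler;
  TF_CHECK_OK(SharedBatchScheduler<MyTask>::Create(options, &scheduler));

  SharedBatchScheduler<MyTask>::QueueOptions queue_options;
  queue_options.max_batch_size = 64;
  queue_options.batch_timeout_micros = 2 * 1000;  // 2 ms
  queue_options.max_enqueued_batches = 1;  // Sum across queues ~= num threads.

  std::unique_ptr<BatchScheduler<MyTask>> model_a_queue;
  std::unique_ptr<BatchScheduler<MyTask>> model_b_queue;
  TF_CHECK_OK(scheduler->AddQueue(
      queue_options,
      [](std::unique_ptr<Batch<MyTask>> batch) {
        // Run model A on 'batch'.
      },
      &model_a_queue));
  TF_CHECK_OK(scheduler->AddQueue(
      queue_options,
      [](std::unique_ptr<Batch<MyTask>> batch) {
        // Run model B on 'batch'.
      },
      &model_b_queue));

  // Each queue implements the BatchScheduler API, and tasks submitted to
  // 'model_a_queue' are never batched together with model B's tasks.
}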
-template <typename TaskType>
-class SharedBatchScheduler
- : public std::enable_shared_from_this<SharedBatchScheduler<TaskType>> {
- public:
- // TODO(b/25089730): Tune defaults based on best practices as they develop.
- struct Options {
- // The name to use for the pool of batch threads.
- string thread_pool_name = {"batch_threads"};
-
- // The number of threads to use to process batches.
- // Must be >= 1, and should be tuned carefully.
- int num_batch_threads = port::NumSchedulableCPUs();
-
- // The environment to use.
- // (Typically only overridden by test code.)
- Env* env = Env::Default();
- };
- // Ownership is shared between the caller of Create() and any queues created
- // via AddQueue().
- static Status Create(
- const Options& options,
- std::shared_ptr<SharedBatchScheduler<TaskType>>* scheduler);
-
- ~SharedBatchScheduler();
-
- // Adds a queue to which tasks may be submitted. The returned queue implements
- // the BatchScheduler API. Each queue has its own set of scheduling options,
- // and its own callback to process batches of tasks submitted to the queue.
- //
- // The returned queue's destructor blocks until all tasks submitted to it have
- // been processed.
- struct QueueOptions {
- // The maximum size of each batch.
- //
- // The scheduler may form batches of any size between 1 and this number
- // (inclusive). If there is a need to quantize the batch sizes, i.e. only
- // submit batches whose size is in a small set of allowed sizes, that can be
- // done by adding padding in the process-batch callback.
- int max_batch_size = 1000;
-
- // If a task has been enqueued for this amount of time (in microseconds),
- // and a thread is available, the scheduler will immediately form a batch
- // from enqueued tasks and assign the batch to the thread for processing,
- // even if the batch's size is below 'max_batch_size'.
- //
- // This parameter offers a way to bound queue latency, so that a task isn't
- // stuck in the queue indefinitely waiting for enough tasks to arrive to
- // make a full batch. (The latency bound is given in the class documentation
- // above.)
- //
- // The goal is to smooth out batch sizes under low request rates, and thus
- // avoid latency spikes.
- int64 batch_timeout_micros = 0;
-
- // The maximum allowable number of enqueued (accepted by Schedule() but
- // not yet being processed on a batch thread) tasks in terms of batches.
- // If this limit is reached, Schedule() will return an UNAVAILABLE error.
- // See the class documentation above for guidelines on how to tune this
- // parameter.
- int max_enqueued_batches = 10;
- };
- Status AddQueue(const QueueOptions& options,
- std::function<void(std::unique_ptr<Batch<TaskType>>)>
- process_batch_callback,
- std::unique_ptr<BatchScheduler<TaskType>>* queue);
-
- private:
- explicit SharedBatchScheduler(const Options& options);
-
- // The code executed in 'batch_threads_'. Obtains a batch to process from the
- // queue pointed to by 'next_queue_to_schedule_', and processes it. If that
- // queue declines to provide a batch to process, moves onto the next queue. If
- // no queues provide a batch to process, just sleeps briefly and exits.
- void ThreadLogic();
-
- const Options options_;
-
- mutex mu_;
-
- // A list of queues. (We use std::list instead of std::vector to ensure that
- // iterators are not invalidated by adding/removing elements. It also offers
- // efficient removal of elements from the middle.)
- using QueueList = std::list<std::unique_ptr<internal::Queue<TaskType>>>;
-
- // All "active" queues, i.e. ones that either:
- // - have not been removed, or
- // - have been removed but are not yet empty.
- QueueList queues_ GUARDED_BY(mu_);
-
- // An iterator over 'queues_', pointing to the queue from which the next
- // available batch thread should grab work.
- typename QueueList::iterator next_queue_to_schedule_ GUARDED_BY(mu_);
-
- // Used by idle batch threads to wait for work to enter the system. Notified
- // whenever a batch becomes schedulable.
- condition_variable schedulable_batch_cv_;
-
- // Threads that process batches obtained from the queues.
- std::vector<std::unique_ptr<PeriodicFunction>> batch_threads_;
-
- TF_DISALLOW_COPY_AND_ASSIGN(SharedBatchScheduler);
-};
-
-//////////
-// Implementation details follow. API users need not read.
-
-namespace internal {
-
-// A task queue for SharedBatchScheduler. Accepts tasks and accumulates them
-// into batches, and dispenses those batches to be processed via a "pull"
-// interface. The queue's behavior is governed by maximum batch size, timeout
-// and maximum queue length parameters; see their documentation in
-// SharedBatchScheduler.
-//
-// The queue is implemented as a deque of batches, with these invariants:
-// - The number of batches is between 1 and 'options_.max_enqueued_batches'.
-// - The back-most batch is open; the rest are closed.
-//
-// Submitted tasks are added to the open batch. If that batch doesn't have room
-// but the queue isn't full, then that batch is closed and a new open batch is
-// started.
-//
-// Batch pull requests are handled by dequeuing the front-most batch if it is
-// closed. If the front-most batch is open (i.e. the queue contains only one
-// batch) and has reached the timeout, it is immediately closed and returned;
-// otherwise no batch is returned for the request.
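For example, with max_batch_size = 10 and max_enqueued_batches = 2: submitting a task of size 6 and then one of size 7 leaves the deque as [closed batch {6}, open batch {7}], because the second task did not fit in the open batch and a new one was started. A subsequent pull request dequeues the closed front batch; once only the open batch remains, a pull returns it (closing it first) only after its timeout has elapsed.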
-template <typename TaskType>
-class Queue {
- public:
- using ProcessBatchCallback =
- std::function<void(std::unique_ptr<Batch<TaskType>>)>;
- using SchedulableBatchCallback = std::function<void()>;
- Queue(const typename SharedBatchScheduler<TaskType>::QueueOptions& options,
- Env* env, ProcessBatchCallback process_batch_callback,
-        SchedulableBatchCallback schedulable_batch_callback);
-
- // Illegal to destruct unless the queue is empty.
- ~Queue();
-
- // Submits a task to the queue, with the same semantics as
- // BatchScheduler::Schedule().
- Status Schedule(std::unique_ptr<TaskType>* task);
-
- // Returns the number of enqueued tasks, with the same semantics as
- // BatchScheduler::NumEnqueuedTasks().
- size_t NumEnqueuedTasks() const;
-
- // Returns the queue capacity, with the same semantics as
- // BatchScheduler::SchedulingCapacity().
- size_t SchedulingCapacity() const;
-
- // Returns the maximum allowed size of tasks submitted to the queue.
- size_t max_task_size() const { return options_.max_batch_size; }
-
- // Called by a thread that is ready to process a batch, to request one from
- // this queue. Either returns a batch that is ready to be processed, or
- // nullptr if the queue declines to schedule a batch at this time. If it
- // returns a batch, the batch is guaranteed to be closed.
- std::unique_ptr<Batch<TaskType>> ScheduleBatch();
-
- // Processes a batch that has been returned earlier by ScheduleBatch().
- void ProcessBatch(std::unique_ptr<Batch<TaskType>> batch);
-
- // Determines whether the queue is empty, i.e. has no tasks waiting or being
- // processed.
- bool IsEmpty() const;
-
- // Marks the queue closed, and waits until it is empty.
- void CloseAndWaitUntilEmpty();
-
- bool closed() const {
- mutex_lock l(mu_);
- return closed_;
- }
-
- private:
- // Same as IsEmpty(), but assumes the caller already holds a lock on 'mu_'.
- bool IsEmptyInternal() const EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
- // Closes the open batch residing at the back of 'batches_', and inserts a
- // fresh open batch behind it.
- void StartNewBatch() EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
- // Determines whether the open batch residing at the back of 'batches_' is
- // currently schedulable.
- bool IsOpenBatchSchedulable() const EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
- const typename SharedBatchScheduler<TaskType>::QueueOptions options_;
-
- // The environment to use.
- Env* env_;
-
-  // A callback invoked to process a batch of work units. Always invoked from
- // a batch thread.
- ProcessBatchCallback process_batch_callback_;
-
- // A callback invoked to notify the scheduler that a new batch has become
- // schedulable.
- SchedulableBatchCallback schedulable_batch_callback_;
-
- mutable mutex mu_;
-
-  // Whether this queue is closed, i.e. whether it can no longer accept new
-  // tasks. This variable is monotonic: it starts as false, and then at some
-  // point gets set to true and remains true for the duration of this object's
-  // life.
- bool closed_ GUARDED_BY(mu_) = false;
-
- // The enqueued batches. See the invariants in the class comments above.
- std::deque<std::unique_ptr<Batch<TaskType>>> batches_ GUARDED_BY(mu_);
-
- // The time at which the first task was added to the open (back-most) batch
- // in 'batches_'. Valid iff that batch contains at least one task.
- uint64 open_batch_start_time_micros_ GUARDED_BY(mu_);
-
- // Whether this queue contains a batch that is eligible to be scheduled. Used
- // to keep track of when to call 'schedulable_batch_callback_'.
- bool schedulable_batch_ GUARDED_BY(mu_) = false;
-
- // The number of batches currently being processed by batch threads.
- // Incremented in ScheduleBatch() and decremented in ProcessBatch().
- int num_batches_being_processed_ GUARDED_BY(mu_) = 0;
-
- // Used by CloseAndWaitUntilEmpty() to wait until the queue is empty, for the
- // case in which the queue is not empty when CloseAndWaitUntilEmpty() starts.
- // When ProcessBatch() dequeues the last batch and makes the queue empty, if
- // 'empty_notification_' is non-null it calls 'empty_notification_->Notify()'.
- Notification* empty_notification_ GUARDED_BY(mu_) = nullptr;
-
- TF_DISALLOW_COPY_AND_ASSIGN(Queue);
-};
-
-// A RAII-style object that points to a Queue and implements
-// the BatchScheduler API. To be handed out to clients who call AddQueue().
-template <typename TaskType>
-class QueueHandle : public BatchScheduler<TaskType> {
- public:
- QueueHandle(std::shared_ptr<SharedBatchScheduler<TaskType>> scheduler,
- Queue<TaskType>* queue);
- ~QueueHandle() override;
-
- Status Schedule(std::unique_ptr<TaskType>* task) override;
- size_t NumEnqueuedTasks() const override;
- size_t SchedulingCapacity() const override;
-
- size_t max_task_size() const override { return queue_->max_task_size(); }
-
- private:
- // The scheduler that owns 'queue_'.
- std::shared_ptr<SharedBatchScheduler<TaskType>> scheduler_;
-
- // The queue this handle wraps. Owned by 'scheduler_', which keeps it alive at
- // least until this class's destructor closes it.
- Queue<TaskType>* queue_;
-
- TF_DISALLOW_COPY_AND_ASSIGN(QueueHandle);
-};
-
-} // namespace internal
-
-template <typename TaskType>
-Status SharedBatchScheduler<TaskType>::Create(
- const Options& options,
- std::shared_ptr<SharedBatchScheduler<TaskType>>* scheduler) {
- if (options.num_batch_threads < 1) {
- return errors::InvalidArgument("num_batch_threads must be positive; was ",
- options.num_batch_threads);
- }
- scheduler->reset(new SharedBatchScheduler<TaskType>(options));
- return Status::OK();
-}
-
-template <typename TaskType>
-SharedBatchScheduler<TaskType>::~SharedBatchScheduler() {
- // Wait until the batch threads finish clearing out and deleting the closed
- // queues.
- for (;;) {
- {
- mutex_lock l(mu_);
- if (queues_.empty()) {
- break;
- }
- }
- const int64 kSleepTimeMicros = 100;
- options_.env->SleepForMicroseconds(kSleepTimeMicros);
- }
- // Delete the batch threads before allowing state the threads may access (e.g.
- // 'mu_') to be deleted.
- batch_threads_.clear();
-}
-
-template <typename TaskType>
-Status SharedBatchScheduler<TaskType>::AddQueue(
- const QueueOptions& options,
- std::function<void(std::unique_ptr<Batch<TaskType>>)>
- process_batch_callback,
- std::unique_ptr<BatchScheduler<TaskType>>* queue) {
- if (options.max_batch_size <= 0) {
- return errors::InvalidArgument("max_batch_size must be positive; was ",
- options.max_batch_size);
- }
- if (options.batch_timeout_micros < 0) {
- return errors::InvalidArgument(
- "batch_timeout_micros must be non-negative; was ",
- options.batch_timeout_micros);
- }
- if (options.max_enqueued_batches < 0) {
- return errors::InvalidArgument(
- "max_enqueued_batches must be non-negative; was ",
- options.max_enqueued_batches);
- }
-
- auto schedulable_batch_callback = [this] {
- mutex_lock l(mu_);
- schedulable_batch_cv_.notify_one();
- };
- auto internal_queue =
- std::unique_ptr<internal::Queue<TaskType>>(new internal::Queue<TaskType>(
- options, options_.env, process_batch_callback,
- schedulable_batch_callback));
- auto handle = std::unique_ptr<BatchScheduler<TaskType>>(
- new internal::QueueHandle<TaskType>(this->shared_from_this(),
- internal_queue.get()));
- {
- mutex_lock l(mu_);
- queues_.push_back(std::move(internal_queue));
- if (next_queue_to_schedule_ == queues_.end()) {
- next_queue_to_schedule_ = queues_.begin();
- }
- }
- *queue = std::move(handle);
- return Status::OK();
-}
-
-template <typename TaskType>
-SharedBatchScheduler<TaskType>::SharedBatchScheduler(const Options& options)
- : options_(options), next_queue_to_schedule_(queues_.end()) {
- // Kick off the batch threads.
- PeriodicFunction::Options periodic_fn_options;
- periodic_fn_options.thread_name_prefix =
- strings::StrCat(options.thread_pool_name, "_");
- for (int i = 0; i < options.num_batch_threads; ++i) {
- std::unique_ptr<PeriodicFunction> thread(new PeriodicFunction(
- [this] { this->ThreadLogic(); },
- 0 /* function invocation interval time */, periodic_fn_options));
- batch_threads_.push_back(std::move(thread));
- }
-}
-
-template <typename TaskType>
-void SharedBatchScheduler<TaskType>::ThreadLogic() {
- // A batch to process next (or nullptr if no work to do).
- std::unique_ptr<Batch<TaskType>> batch_to_process;
- // The queue with which 'batch_to_process' is associated.
- internal::Queue<TaskType>* queue_for_batch = nullptr;
- {
- mutex_lock l(mu_);
-
- const int num_queues = queues_.size();
- for (int num_queues_tried = 0;
- batch_to_process == nullptr && num_queues_tried < num_queues;
- ++num_queues_tried) {
- DCHECK(next_queue_to_schedule_ != queues_.end());
-
- // If a closed queue responds to ScheduleBatch() with nullptr, the queue
- // will never yield any further batches so we can drop it. To avoid a
- // race, we take a snapshot of the queue's closedness state *before*
- // calling ScheduleBatch().
- const bool queue_closed = (*next_queue_to_schedule_)->closed();
-
- // Ask '*next_queue_to_schedule_' if it wants us to process a batch.
- batch_to_process = (*next_queue_to_schedule_)->ScheduleBatch();
- if (batch_to_process != nullptr) {
- queue_for_batch = next_queue_to_schedule_->get();
- }
-
- // Advance 'next_queue_to_schedule_'.
- if (queue_closed && (*next_queue_to_schedule_)->IsEmpty() &&
- batch_to_process == nullptr) {
- // We've encountered a closed queue with no work to do. Drop it.
- DCHECK_NE(queue_for_batch, next_queue_to_schedule_->get());
- next_queue_to_schedule_ = queues_.erase(next_queue_to_schedule_);
- } else {
- ++next_queue_to_schedule_;
- }
- if (next_queue_to_schedule_ == queues_.end() && !queues_.empty()) {
- // We've hit the end. Wrap to the first queue.
- next_queue_to_schedule_ = queues_.begin();
- }
- }
-
- if (batch_to_process == nullptr) {
- // We couldn't find any work to do. Wait until a new batch becomes
- // schedulable, or some time has elapsed, before checking again.
- const int64 kTimeoutMillis = 1; // The smallest accepted granule of time.
- WaitForMilliseconds(&l, &schedulable_batch_cv_, kTimeoutMillis);
- return;
- }
- }
-
- queue_for_batch->ProcessBatch(std::move(batch_to_process));
-}
-
-namespace internal {
-
-template <typename TaskType>
-Queue<TaskType>::Queue(
- const typename SharedBatchScheduler<TaskType>::QueueOptions& options,
- Env* env, ProcessBatchCallback process_batch_callback,
- SchedulableBatchCallback schedulable_batch_callback)
- : options_(options),
- env_(env),
- process_batch_callback_(process_batch_callback),
- schedulable_batch_callback_(schedulable_batch_callback) {
- // Create an initial, open batch.
- batches_.emplace_back(new Batch<TaskType>);
-}
-
-template <typename TaskType>
-Queue<TaskType>::~Queue() {
- mutex_lock l(mu_);
- DCHECK(IsEmptyInternal());
-
- // Close the (empty) open batch, so its destructor doesn't block.
- batches_.back()->Close();
-}
-
-template <typename TaskType>
-Status Queue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
- if ((*task)->size() > options_.max_batch_size) {
- return errors::InvalidArgument("Task size ", (*task)->size(),
- " is larger than maximum batch size ",
- options_.max_batch_size);
- }
-
- bool notify_of_schedulable_batch = false;
- {
- mutex_lock l(mu_);
-
- DCHECK(!closed_);
-
- if (batches_.back()->size() + (*task)->size() > options_.max_batch_size) {
- if (batches_.size() >= options_.max_enqueued_batches) {
- return errors::Unavailable(
- "The batch scheduling queue to which this task was submitted is "
- "full");
- }
- StartNewBatch();
- }
- if (batches_.back()->empty()) {
- open_batch_start_time_micros_ = env_->NowMicros();
- }
- batches_.back()->AddTask(std::move(*task));
-
- if (!schedulable_batch_) {
- if (batches_.size() > 1 || IsOpenBatchSchedulable()) {
- schedulable_batch_ = true;
- notify_of_schedulable_batch = true;
- }
- }
- }
-
- if (notify_of_schedulable_batch) {
- schedulable_batch_callback_();
- }
-
- return Status::OK();
-}
-
-template <typename TaskType>
-size_t Queue<TaskType>::NumEnqueuedTasks() const {
- mutex_lock l(mu_);
- size_t num_enqueued_tasks = 0;
- for (const auto& batch : batches_) {
- num_enqueued_tasks += batch->num_tasks();
- }
- return num_enqueued_tasks;
-}
-
-template <typename TaskType>
-size_t Queue<TaskType>::SchedulingCapacity() const {
- mutex_lock l(mu_);
- const int num_new_batches_schedulable =
- options_.max_enqueued_batches - batches_.size();
- const int open_batch_capacity =
- options_.max_batch_size - batches_.back()->size();
- return (num_new_batches_schedulable * options_.max_batch_size) +
- open_batch_capacity;
-}
-
-template <typename TaskType>
-std::unique_ptr<Batch<TaskType>> Queue<TaskType>::ScheduleBatch() {
- // The batch to schedule, which we may populate below. (If left as nullptr,
- // that means we are electing not to schedule a batch at this time.)
- std::unique_ptr<Batch<TaskType>> batch_to_schedule;
-
- {
- mutex_lock l(mu_);
-
- // Consider closing the open batch at this time, to schedule it.
- if (batches_.size() == 1 && IsOpenBatchSchedulable()) {
- StartNewBatch();
- }
-
- if (batches_.size() >= 2) {
- // There is at least one closed batch that is ready to be scheduled.
- ++num_batches_being_processed_;
- batch_to_schedule = std::move(batches_.front());
- batches_.pop_front();
- } else {
- schedulable_batch_ = false;
- }
- }
-
- return batch_to_schedule;
-}
-
-template <typename TaskType>
-void Queue<TaskType>::ProcessBatch(std::unique_ptr<Batch<TaskType>> batch) {
- process_batch_callback_(std::move(batch));
-
- {
- mutex_lock l(mu_);
- --num_batches_being_processed_;
- if (empty_notification_ != nullptr && IsEmptyInternal()) {
- empty_notification_->Notify();
- }
- }
-}
-
-template <typename TaskType>
-bool Queue<TaskType>::IsEmpty() const {
- mutex_lock l(mu_);
- return IsEmptyInternal();
-}
-
-template <typename TaskType>
-void Queue<TaskType>::CloseAndWaitUntilEmpty() {
- Notification empty;
- {
- mutex_lock l(mu_);
- closed_ = true;
- if (IsEmptyInternal()) {
- empty.Notify();
- } else {
- // Arrange for ProcessBatch() to notify when the queue becomes empty.
- empty_notification_ = &empty;
- }
- }
- empty.WaitForNotification();
-}
-
-template <typename TaskType>
-bool Queue<TaskType>::IsEmptyInternal() const {
- return num_batches_being_processed_ == 0 && batches_.size() == 1 &&
- batches_.back()->empty();
-}
-
-template <typename TaskType>
-void Queue<TaskType>::StartNewBatch() {
- batches_.back()->Close();
- batches_.emplace_back(new Batch<TaskType>);
-}
-
-template <typename TaskType>
-bool Queue<TaskType>::IsOpenBatchSchedulable() const {
- Batch<TaskType>* open_batch = batches_.back().get();
- if (open_batch->empty()) {
- return false;
- }
- return closed_ || open_batch->size() >= options_.max_batch_size ||
- env_->NowMicros() >=
- open_batch_start_time_micros_ + options_.batch_timeout_micros;
-}
-
-template <typename TaskType>
-QueueHandle<TaskType>::QueueHandle(
- std::shared_ptr<SharedBatchScheduler<TaskType>> scheduler,
- Queue<TaskType>* queue)
- : scheduler_(scheduler), queue_(queue) {}
-
-template <typename TaskType>
-QueueHandle<TaskType>::~QueueHandle() {
- queue_->CloseAndWaitUntilEmpty();
-}
-
-template <typename TaskType>
-Status QueueHandle<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
- return queue_->Schedule(task);
-}
-
-template <typename TaskType>
-size_t QueueHandle<TaskType>::NumEnqueuedTasks() const {
- return queue_->NumEnqueuedTasks();
-}
-
-template <typename TaskType>
-size_t QueueHandle<TaskType>::SchedulingCapacity() const {
- return queue_->SchedulingCapacity();
-}
-
-} // namespace internal
-
-} // namespace serving
-} // namespace tensorflow
+#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_
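
The contrib header above is now only a forwarding shim; the scheduler implementation lives under tensorflow/core/kernels/batching_util. For orientation, here is a minimal usage sketch of the API that the deleted code implements (Create(), AddQueue(), Schedule()). The MyTask type, the option values, and the callback body are illustrative only and are not part of this change.

#include <memory>

#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace serving {

// Illustrative task type; a real task reports its size in whatever unit the
// batching policy cares about (e.g. number of examples).
class MyTask : public BatchTask {
 public:
  explicit MyTask(size_t size) : size_(size) {}
  size_t size() const override { return size_; }

 private:
  const size_t size_;
};

Status RunSchedulerSketch() {
  SharedBatchScheduler<MyTask>::Options options;
  options.num_batch_threads = 2;
  std::shared_ptr<SharedBatchScheduler<MyTask>> scheduler;
  TF_RETURN_IF_ERROR(SharedBatchScheduler<MyTask>::Create(options, &scheduler));

  SharedBatchScheduler<MyTask>::QueueOptions queue_options;
  queue_options.max_batch_size = 10;
  queue_options.batch_timeout_micros = 1000;
  queue_options.max_enqueued_batches = 2;
  std::unique_ptr<BatchScheduler<MyTask>> queue;
  TF_RETURN_IF_ERROR(scheduler->AddQueue(
      queue_options,
      [](std::unique_ptr<Batch<MyTask>> batch) { /* process 'batch' */ },
      &queue));

  std::unique_ptr<MyTask> task(new MyTask(3));
  return queue->Schedule(&task);  // Consumes 'task' iff it returns OK.
}

}  // namespace serving
}  // namespace tensorflow
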
diff --git a/tensorflow/contrib/batching/shared_batch_scheduler_test.cc b/tensorflow/contrib/batching/shared_batch_scheduler_test.cc
deleted file mode 100644
index 3ac79a8fdc..0000000000
--- a/tensorflow/contrib/batching/shared_batch_scheduler_test.cc
+++ /dev/null
@@ -1,597 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/contrib/batching/shared_batch_scheduler.h"
-
-#include "tensorflow/contrib/batching/test_util/fake_clock_env.h"
-#include "tensorflow/core/lib/core/error_codes.pb.h"
-#include "tensorflow/core/lib/core/notification.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace serving {
-namespace {
-
-class FakeTask : public BatchTask {
- public:
- explicit FakeTask(size_t size) : size_(size) {}
-
- ~FakeTask() override = default;
-
- size_t size() const override { return size_; }
-
- private:
- const size_t size_;
-
- TF_DISALLOW_COPY_AND_ASSIGN(FakeTask);
-};
-
-// Creates a FakeTask of size 'task_size', and calls 'scheduler->Schedule()' on
-// that task. Returns the resulting status.
-Status ScheduleTask(size_t task_size, BatchScheduler<FakeTask>* scheduler) {
- std::unique_ptr<FakeTask> task(new FakeTask(task_size));
- Status status = scheduler->Schedule(&task);
- // Schedule() should have consumed 'task' iff it returned Status::OK.
- CHECK_EQ(status.ok(), task == nullptr);
- return status;
-}
-
-// Creates a thread that waits on 'start' and then advances the fake clock in
-// 'env' in a loop until 'stop' is notified. Useful for allowing objects that
-// use the clock to be destroyed.
-std::unique_ptr<Thread> CreateFakeClockAdvancerThread(
- test_util::FakeClockEnv* env, Notification* start, Notification* stop) {
- return std::unique_ptr<Thread>(
- Env::Default()->StartThread({}, "FakeClockAdvancerThread",
- [env, start, stop] {
- start->WaitForNotification();
- while (!stop->HasBeenNotified()) {
- env->AdvanceByMicroseconds(10);
- Env::Default()->SleepForMicroseconds(10);
- }
- }));
-}
-
-TEST(SharedBatchSchedulerTest, Basic) {
- for (int num_batch_threads : {1, 2, 3}) {
- for (const bool delete_scheduler_early : {false, true}) {
- for (const bool delete_queue_1_early : {false, true}) {
- bool queue_0_callback_called = false;
- auto queue_0_callback =
- [&queue_0_callback_called](std::unique_ptr<Batch<FakeTask>> batch) {
- queue_0_callback_called = true;
- ASSERT_TRUE(batch->IsClosed());
- ASSERT_EQ(3, batch->num_tasks());
- EXPECT_EQ(1, batch->task(0).size());
- EXPECT_EQ(3, batch->task(1).size());
- EXPECT_EQ(5, batch->task(2).size());
- };
- bool queue_1_callback_called = false;
- auto queue_1_callback =
- [&queue_1_callback_called](std::unique_ptr<Batch<FakeTask>> batch) {
- queue_1_callback_called = true;
- ASSERT_TRUE(batch->IsClosed());
- ASSERT_EQ(2, batch->num_tasks());
- EXPECT_EQ(2, batch->task(0).size());
- EXPECT_EQ(4, batch->task(1).size());
- };
- {
- SharedBatchScheduler<FakeTask>::Options options;
- options.num_batch_threads = num_batch_threads;
- std::shared_ptr<SharedBatchScheduler<FakeTask>> scheduler;
- TF_ASSERT_OK(
- SharedBatchScheduler<FakeTask>::Create(options, &scheduler));
-
- // Create two queues.
- SharedBatchScheduler<FakeTask>::QueueOptions queue_options;
- queue_options.max_batch_size = 10;
- queue_options.batch_timeout_micros = 10 * 1000 * 1000; // 10 seconds
- queue_options.max_enqueued_batches = 2;
- std::unique_ptr<BatchScheduler<FakeTask>> queue_0;
- TF_ASSERT_OK(
- scheduler->AddQueue(queue_options, queue_0_callback, &queue_0));
- std::unique_ptr<BatchScheduler<FakeTask>> queue_1;
- TF_ASSERT_OK(
- scheduler->AddQueue(queue_options, queue_1_callback, &queue_1));
-
- if (delete_scheduler_early) {
- // Delete our copy of the scheduler. The queues should keep it alive
- // under the covers.
- scheduler = nullptr;
- }
-
- // Submit tasks to the two queues, and (optionally) remove the queues.
- TF_ASSERT_OK(ScheduleTask(1, queue_0.get()));
- TF_ASSERT_OK(ScheduleTask(2, queue_1.get()));
- TF_ASSERT_OK(ScheduleTask(3, queue_0.get()));
- TF_ASSERT_OK(ScheduleTask(4, queue_1.get()));
- if (delete_queue_1_early) {
- queue_1 = nullptr;
- }
- TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
- }
- EXPECT_TRUE(queue_0_callback_called);
- EXPECT_TRUE(queue_1_callback_called);
- }
- }
- }
-}
-
-TEST(SharedBatchSchedulerTest, ObeyBatchSizeConstraint) {
- // Set up a callback that captures the batches' task sizes.
- mutex mu;
- std::vector<std::vector<size_t>> callback_data;
- auto callback = [&mu,
- &callback_data](std::unique_ptr<Batch<FakeTask>> batch) {
- ASSERT_TRUE(batch->IsClosed());
- std::vector<size_t> batch_data;
- batch_data.reserve(batch->num_tasks());
- for (int i = 0; i < batch->num_tasks(); ++i) {
- batch_data.push_back(batch->mutable_task(i)->size());
- }
- {
- mutex_lock l(mu);
- callback_data.push_back(batch_data);
- }
- };
-
- // Run a batch scheduler and inject some tasks.
- {
- SharedBatchScheduler<FakeTask>::Options options;
- options.num_batch_threads = 2;
- std::shared_ptr<SharedBatchScheduler<FakeTask>> scheduler;
- TF_ASSERT_OK(SharedBatchScheduler<FakeTask>::Create(options, &scheduler));
- SharedBatchScheduler<FakeTask>::QueueOptions queue_options;
- queue_options.max_batch_size = 10;
- queue_options.batch_timeout_micros = 10 * 1000 * 1000; // 10 seconds
- queue_options.max_enqueued_batches = 2;
- std::unique_ptr<BatchScheduler<FakeTask>> queue;
- TF_ASSERT_OK(scheduler->AddQueue(queue_options, callback, &queue));
-
- // First batch.
- TF_ASSERT_OK(ScheduleTask(3, queue.get()));
- TF_ASSERT_OK(ScheduleTask(5, queue.get()));
-
- // Second batch (due to size overage).
- TF_ASSERT_OK(ScheduleTask(3 /* (3+5) + 3 > 10 */, queue.get()));
- TF_ASSERT_OK(ScheduleTask(1, queue.get()));
- TF_ASSERT_OK(ScheduleTask(6, queue.get()));
-
- // (Empty third batch, since the second batch exactly hit the size limit,
- // which should never get sent to the callback.)
- }
-
- // Expect a certain grouping of the tasks into batches.
- ASSERT_EQ(2, callback_data.size());
- ASSERT_TRUE((callback_data[0].size() == 2 && callback_data[1].size() == 3) ||
- (callback_data[0].size() == 3 && callback_data[1].size() == 2));
- const std::vector<size_t>& callback_data_a =
- callback_data[0].size() == 2 ? callback_data[0] : callback_data[1];
- const std::vector<size_t>& callback_data_b =
- callback_data[0].size() == 2 ? callback_data[1] : callback_data[0];
- EXPECT_EQ((std::vector<size_t>{3, 5}), callback_data_a);
- EXPECT_EQ((std::vector<size_t>{3, 1, 6}), callback_data_b);
-}
-
-TEST(SharedBatchSchedulerTest, ObeysTimeout) {
- // Set up a fake clock, which only advances when we explicitly tell it to.
- test_util::FakeClockEnv env(Env::Default());
- Notification start_teardown, stop_teardown;
- std::unique_ptr<Thread> teardown_thread =
- CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
-
- {
- Notification first_batch_processed, second_batch_processed,
- third_batch_processed;
- auto callback =
- [&first_batch_processed, &second_batch_processed,
- &third_batch_processed](std::unique_ptr<Batch<FakeTask>> batch) {
- ASSERT_TRUE(batch->IsClosed());
- if (batch->size() == 1) {
- first_batch_processed.Notify();
- } else if (batch->size() == 2) {
- second_batch_processed.Notify();
- } else if (batch->size() == 3) {
- third_batch_processed.Notify();
- } else {
- EXPECT_TRUE(false) << "Unexpected batch size";
- }
- };
-
- SharedBatchScheduler<FakeTask>::Options options;
- options.num_batch_threads = 1;
- options.env = &env;
- std::shared_ptr<SharedBatchScheduler<FakeTask>> scheduler;
- TF_ASSERT_OK(SharedBatchScheduler<FakeTask>::Create(options, &scheduler));
- SharedBatchScheduler<FakeTask>::QueueOptions queue_options;
- queue_options.max_batch_size = 4;
- queue_options.batch_timeout_micros = 10;
- queue_options.max_enqueued_batches = 2;
- std::unique_ptr<BatchScheduler<FakeTask>> queue;
- TF_ASSERT_OK(scheduler->AddQueue(queue_options, callback, &queue));
-
- // Create an underfull batch, and ensure that it gets processed when the
- // clock hits the timeout.
- TF_ASSERT_OK(ScheduleTask(1, queue.get()));
- env.AdvanceByMicroseconds(9);
- Env::Default()->SleepForMicroseconds(10 * 1000 /* 10 milliseconds */);
- EXPECT_FALSE(first_batch_processed.HasBeenNotified());
- env.AdvanceByMicroseconds(1);
- first_batch_processed.WaitForNotification();
-
- // Start creating a batch, while leaving the clock well below the timeout.
- // Then submit a new task that overflows into the next batch, causing
- // the original batch to close.
- TF_ASSERT_OK(ScheduleTask(2, queue.get()));
- Env::Default()->SleepForMicroseconds(10 * 1000 /* 10 milliseconds */);
- EXPECT_FALSE(second_batch_processed.HasBeenNotified());
- TF_ASSERT_OK(ScheduleTask(3, queue.get()));
- second_batch_processed.WaitForNotification();
-
- // Allow the third batch to hit its timeout, and ensure it gets closed at
- // the right time.
- env.AdvanceByMicroseconds(9);
- Env::Default()->SleepForMicroseconds(10 * 1000 /* 10 milliseconds */);
- EXPECT_FALSE(third_batch_processed.HasBeenNotified());
- env.AdvanceByMicroseconds(1);
- third_batch_processed.WaitForNotification();
-
- start_teardown.Notify();
- }
- stop_teardown.Notify();
-}
-
-TEST(SharedBatchSchedulerTest, ObeysTimeoutWithRealClock) {
- Notification first_batch_processed, second_batch_processed;
- auto callback = [&first_batch_processed, &second_batch_processed](
- std::unique_ptr<Batch<FakeTask>> batch) {
- ASSERT_TRUE(batch->IsClosed());
- if (batch->size() == 1) {
- first_batch_processed.Notify();
- } else if (batch->size() == 2) {
- second_batch_processed.Notify();
- } else {
- EXPECT_TRUE(false) << "Unexpected batch size";
- }
- };
-
- SharedBatchScheduler<FakeTask>::Options options;
- options.num_batch_threads = 2;
- std::shared_ptr<SharedBatchScheduler<FakeTask>> scheduler;
- TF_ASSERT_OK(SharedBatchScheduler<FakeTask>::Create(options, &scheduler));
- SharedBatchScheduler<FakeTask>::QueueOptions queue_options;
- queue_options.max_batch_size = 10;
- queue_options.batch_timeout_micros = 100 * 1000; // 100 milliseconds
- queue_options.max_enqueued_batches = 2;
- std::unique_ptr<BatchScheduler<FakeTask>> queue;
- TF_ASSERT_OK(scheduler->AddQueue(queue_options, callback, &queue));
-
- // Submit a single task that doesn't fill up the batch.
- // Ensure that it gets processed due to the timeout.
- TF_ASSERT_OK(ScheduleTask(1, queue.get()));
- first_batch_processed.WaitForNotification();
-
- // Do it again.
- TF_ASSERT_OK(ScheduleTask(2, queue.get()));
- second_batch_processed.WaitForNotification();
-}
-
-TEST(SharedBatchSchedulerTest,
- WithZeroTimeoutBatchesScheduledAsSoonAsThreadIsAvailable) {
- // Set up a fake clock, and never advance the time.
- test_util::FakeClockEnv env(Env::Default());
- Notification start_teardown, stop_teardown;
- std::unique_ptr<Thread> teardown_thread =
- CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
-
- {
- Notification first_batch_processed, second_batch_processed;
- auto callback = [&first_batch_processed, &second_batch_processed](
- std::unique_ptr<Batch<FakeTask>> batch) {
- ASSERT_TRUE(batch->IsClosed());
- if (batch->size() == 1) {
- first_batch_processed.Notify();
- } else if (batch->size() == 2) {
- second_batch_processed.Notify();
- } else {
- EXPECT_TRUE(false) << "Unexpected batch size";
- }
- };
-
- SharedBatchScheduler<FakeTask>::Options options;
- options.num_batch_threads = 2;
- options.env = &env;
- std::shared_ptr<SharedBatchScheduler<FakeTask>> scheduler;
- TF_ASSERT_OK(SharedBatchScheduler<FakeTask>::Create(options, &scheduler));
- SharedBatchScheduler<FakeTask>::QueueOptions queue_options;
- // Set a large batch size, so that we don't hit the batch size limit.
- queue_options.max_batch_size = 100;
- // Process a batch as soon as a thread is available.
- queue_options.batch_timeout_micros = 0;
- queue_options.max_enqueued_batches = 2;
- std::unique_ptr<BatchScheduler<FakeTask>> queue;
- TF_ASSERT_OK(scheduler->AddQueue(queue_options, callback, &queue));
-
- TF_ASSERT_OK(ScheduleTask(1, queue.get()));
- first_batch_processed.WaitForNotification();
- TF_ASSERT_OK(ScheduleTask(2, queue.get()));
- second_batch_processed.WaitForNotification();
-
- // Shut everything down.
- start_teardown.Notify();
- }
- stop_teardown.Notify();
-}
-
-TEST(SharedBatchSchedulerTest, Fairness) {
- test_util::FakeClockEnv env(Env::Default());
- Notification start_teardown, stop_teardown;
- std::unique_ptr<Thread> teardown_thread =
- CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
-
- {
- Notification queue_0_first_batch_scheduled, queue_0_first_batch_proceed,
- queue_0_second_batch_scheduled;
- auto queue_0_callback = [&queue_0_first_batch_scheduled,
- &queue_0_first_batch_proceed,
- &queue_0_second_batch_scheduled](
- std::unique_ptr<Batch<FakeTask>> batch) {
- if (!queue_0_first_batch_scheduled.HasBeenNotified()) {
- queue_0_first_batch_scheduled.Notify();
- queue_0_first_batch_proceed.WaitForNotification();
- } else if (!queue_0_second_batch_scheduled.HasBeenNotified()) {
- queue_0_second_batch_scheduled.Notify();
- }
- };
-
- Notification queue_1_first_batch_scheduled, queue_1_first_batch_proceed;
- auto queue_1_callback =
- [&queue_1_first_batch_scheduled,
- &queue_1_first_batch_proceed](std::unique_ptr<Batch<FakeTask>> batch) {
- queue_1_first_batch_scheduled.Notify();
- queue_1_first_batch_proceed.WaitForNotification();
- };
-
- SharedBatchScheduler<FakeTask>::Options options;
- options.num_batch_threads = 1;
- options.env = &env;
- std::shared_ptr<SharedBatchScheduler<FakeTask>> scheduler;
- TF_ASSERT_OK(SharedBatchScheduler<FakeTask>::Create(options, &scheduler));
- SharedBatchScheduler<FakeTask>::QueueOptions queue_options;
- queue_options.max_batch_size = 10;
- queue_options.batch_timeout_micros = 1;
- queue_options.max_enqueued_batches = 100 /* give plenty of room */;
- std::vector<std::unique_ptr<BatchScheduler<FakeTask>>> queues(2);
- TF_ASSERT_OK(
- scheduler->AddQueue(queue_options, queue_0_callback, &queues[0]));
- TF_ASSERT_OK(
- scheduler->AddQueue(queue_options, queue_1_callback, &queues[1]));
-
- // Enqueue a batch-filling task to queue 0, and wait for it to get
- // scheduled.
- TF_ASSERT_OK(ScheduleTask(10, queues[0].get()));
- env.AdvanceByMicroseconds(1);
- queue_0_first_batch_scheduled.WaitForNotification();
-
- // Enqueue two more batch-filling tasks to queue 0.
- TF_ASSERT_OK(ScheduleTask(10, queues[0].get()));
- TF_ASSERT_OK(ScheduleTask(10, queues[0].get()));
-
- // Enqueue one task to queue 1, and then advance the clock so it becomes
- // eligible for scheduling due to the timeout. Ensure that the queue 1 batch
- // gets scheduled before the next queue 0 one.
- TF_ASSERT_OK(ScheduleTask(1, queues[1].get()));
- env.AdvanceByMicroseconds(1);
- queue_0_first_batch_proceed.Notify();
- queue_1_first_batch_scheduled.WaitForNotification();
- Env::Default()->SleepForMicroseconds(10 * 1000 /* 10 milliseconds */);
- EXPECT_FALSE(queue_0_second_batch_scheduled.HasBeenNotified());
-
- // Shut everything down.
- queue_1_first_batch_proceed.Notify();
- start_teardown.Notify();
- }
- stop_teardown.Notify();
-}
-
-TEST(SharedBatchSchedulerTest, ConstMethods) {
- for (const int max_enqueued_batches : {1, 2, 5}) {
- Notification processing, proceed;
- auto callback = [&processing,
- &proceed](std::unique_ptr<Batch<FakeTask>> batch) {
- if (!processing.HasBeenNotified()) {
- processing.Notify();
- }
- proceed.WaitForNotification();
- };
-
- SharedBatchScheduler<FakeTask>::Options options;
- options.num_batch_threads = 1;
- std::shared_ptr<SharedBatchScheduler<FakeTask>> scheduler;
- TF_ASSERT_OK(SharedBatchScheduler<FakeTask>::Create(options, &scheduler));
- SharedBatchScheduler<FakeTask>::QueueOptions queue_options;
- queue_options.max_batch_size = 2;
- queue_options.batch_timeout_micros = 0;
- queue_options.max_enqueued_batches = max_enqueued_batches;
- std::unique_ptr<BatchScheduler<FakeTask>> queue;
- TF_ASSERT_OK(scheduler->AddQueue(queue_options, callback, &queue));
- EXPECT_EQ(2, queue->max_task_size());
- EXPECT_EQ(0, queue->NumEnqueuedTasks());
- EXPECT_EQ(max_enqueued_batches * 2, queue->SchedulingCapacity());
-
- // Get one batch going on the thread, and keep the thread blocked until
- // we're done testing the maximum queue length.
- TF_ASSERT_OK(ScheduleTask(2, queue.get()));
- processing.WaitForNotification();
- EXPECT_EQ(0, queue->NumEnqueuedTasks());
-
- // We should be able to enqueue 'max_enqueued_batches'*2 tasks without
- // issue.
- for (int i = 0; i < max_enqueued_batches; ++i) {
- EXPECT_EQ(i * 2, queue->NumEnqueuedTasks());
- EXPECT_EQ((max_enqueued_batches - i) * 2, queue->SchedulingCapacity());
- TF_ASSERT_OK(ScheduleTask(1, queue.get()));
- EXPECT_EQ((i * 2) + 1, queue->NumEnqueuedTasks());
- EXPECT_EQ((max_enqueued_batches - i) * 2 - 1,
- queue->SchedulingCapacity());
- TF_ASSERT_OK(ScheduleTask(1, queue.get()));
- }
- EXPECT_EQ(max_enqueued_batches * 2, queue->NumEnqueuedTasks());
- EXPECT_EQ(0, queue->SchedulingCapacity());
-
- // Attempting to enqueue one more task should yield an UNAVAILABLE error.
- Status status = ScheduleTask(1, queue.get());
- ASSERT_FALSE(status.ok());
- EXPECT_EQ(error::UNAVAILABLE, status.code());
- EXPECT_EQ(max_enqueued_batches * 2, queue->NumEnqueuedTasks());
- EXPECT_EQ(0, queue->SchedulingCapacity());
-
- proceed.Notify();
- }
-}
-
-TEST(SharedBatchSchedulerTest, OneFullQueueDoesntBlockOtherQueues) {
- Notification queue_0_processing, queue_0_proceed;
- auto queue_0_callback = [&queue_0_processing, &queue_0_proceed](
- std::unique_ptr<Batch<FakeTask>> batch) {
- if (!queue_0_processing.HasBeenNotified()) {
- queue_0_processing.Notify();
- queue_0_proceed.WaitForNotification();
- }
- };
-
- Notification queue_1_first_batch_processed, queue_1_second_batch_processed,
- queue_1_third_batch_processed;
- auto queue_1_callback =
- [&queue_1_first_batch_processed, &queue_1_second_batch_processed,
- &queue_1_third_batch_processed](std::unique_ptr<Batch<FakeTask>> batch) {
- if (batch->size() == 1) {
- queue_1_first_batch_processed.Notify();
- } else if (batch->size() == 2) {
- queue_1_second_batch_processed.Notify();
- } else if (batch->size() == 3) {
- queue_1_third_batch_processed.Notify();
- } else {
- EXPECT_TRUE(false) << "Unexpected batch size";
- }
- };
-
- SharedBatchScheduler<FakeTask>::Options options;
- options.num_batch_threads = 2;
- std::shared_ptr<SharedBatchScheduler<FakeTask>> scheduler;
- TF_ASSERT_OK(SharedBatchScheduler<FakeTask>::Create(options, &scheduler));
- SharedBatchScheduler<FakeTask>::QueueOptions queue_options;
- queue_options.max_batch_size = 10;
- queue_options.batch_timeout_micros = 0;
- queue_options.max_enqueued_batches = 2;
- std::unique_ptr<BatchScheduler<FakeTask>> queue_0;
- TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_0_callback, &queue_0));
- std::unique_ptr<BatchScheduler<FakeTask>> queue_1;
- TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_1_callback, &queue_1));
-
- // Clog up queue 0.
- TF_ASSERT_OK(ScheduleTask(1, queue_0.get()));
- queue_0_processing.WaitForNotification();
- Status queue_0_status;
- do {
- queue_0_status = ScheduleTask(1, queue_0.get());
- } while (queue_0_status.ok());
- EXPECT_EQ(error::UNAVAILABLE, queue_0_status.code());
-
- // Ensure that queue 1 still behaves normally, and lets us process tasks.
- TF_ASSERT_OK(ScheduleTask(1, queue_1.get()));
- queue_1_first_batch_processed.WaitForNotification();
- TF_ASSERT_OK(ScheduleTask(2, queue_1.get()));
- queue_1_second_batch_processed.WaitForNotification();
- TF_ASSERT_OK(ScheduleTask(3, queue_1.get()));
- queue_1_third_batch_processed.WaitForNotification();
-
- // Let poor queue 0 drain.
- queue_0_proceed.Notify();
-}
-
-TEST(SharedBatchSchedulerTest, QueueDestructorBlocksUntilAllTasksProcessed) {
- test_util::FakeClockEnv env(Env::Default());
- Notification start_teardown, stop_teardown;
- std::unique_ptr<Thread> teardown_thread =
- CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
-
- {
- int current_batch = 0;
- Notification first_callback_started;
- const int kMaxEnqueuedBatches = 3;
- std::vector<Notification> callback_proceed(kMaxEnqueuedBatches);
- auto callback =
- [&current_batch, &first_callback_started,
- &callback_proceed](std::unique_ptr<Batch<FakeTask>> batch) {
- if (current_batch == 0) {
- first_callback_started.Notify();
- }
- callback_proceed[current_batch].WaitForNotification();
- ++current_batch;
- };
-
- SharedBatchScheduler<FakeTask>::Options options;
- options.num_batch_threads = 1;
- options.env = &env;
- std::shared_ptr<SharedBatchScheduler<FakeTask>> scheduler;
- TF_ASSERT_OK(SharedBatchScheduler<FakeTask>::Create(options, &scheduler));
- SharedBatchScheduler<FakeTask>::QueueOptions queue_options;
- queue_options.max_batch_size = 10;
- queue_options.batch_timeout_micros = 0;
- queue_options.max_enqueued_batches = 2;
- std::unique_ptr<BatchScheduler<FakeTask>> queue;
- TF_ASSERT_OK(scheduler->AddQueue(queue_options, callback, &queue));
-
- // Clog up the queue.
- int num_enqueued_batches = 0;
- TF_ASSERT_OK(ScheduleTask(10, queue.get()));
- ++num_enqueued_batches;
- env.AdvanceByMicroseconds(1);
- first_callback_started.WaitForNotification();
- for (int i = 0; i < 2; ++i) {
- TF_ASSERT_OK(ScheduleTask(10, queue.get()));
- ++num_enqueued_batches;
- }
- EXPECT_EQ(kMaxEnqueuedBatches, num_enqueued_batches);
- EXPECT_EQ(error::UNAVAILABLE, ScheduleTask(10, queue.get()).code());
-
- // Destroy the queue. The destructor should block until all tasks have been
- // processed.
- Notification destroy_queue_thread_started, queue_destroyed;
- std::unique_ptr<Thread> destroy_queue_thread(Env::Default()->StartThread(
- {}, "DestroyQueueThread",
- [&queue, &destroy_queue_thread_started, &queue_destroyed] {
- destroy_queue_thread_started.Notify();
- queue = nullptr;
- queue_destroyed.Notify();
- }));
- destroy_queue_thread_started.WaitForNotification();
- for (int i = 0; i < num_enqueued_batches; ++i) {
- Env::Default()->SleepForMicroseconds(10 * 1000 /* 10 milliseconds */);
- EXPECT_FALSE(queue_destroyed.HasBeenNotified());
- callback_proceed[i].Notify();
- }
-
- start_teardown.Notify();
- }
- stop_teardown.Notify();
-}
-
-} // namespace
-} // namespace serving
-} // namespace tensorflow
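
As the ConstMethods and OneFullQueueDoesntBlockOtherQueues tests above show, once a queue holds max_enqueued_batches closed batches, Schedule() returns an UNAVAILABLE error and leaves the task with the caller. A hedged sketch of how a caller might retry in that case; the helper name and the fixed 1 ms backoff are assumptions for illustration, not part of the scheduler API.

#include <memory>

#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {
namespace serving {

// Hypothetical helper: retries Schedule() while the queue reports it is full.
// On UNAVAILABLE the task is not consumed, so it can simply be resubmitted.
template <typename TaskType>
Status ScheduleWithRetry(BatchScheduler<TaskType>* queue,
                         std::unique_ptr<TaskType>* task, int max_attempts) {
  Status status;
  for (int attempt = 0; attempt < max_attempts; ++attempt) {
    status = queue->Schedule(task);
    if (status.code() != error::UNAVAILABLE) {
      return status;  // OK (task consumed) or a non-retriable error.
    }
    Env::Default()->SleepForMicroseconds(1000);  // Queue full; back off briefly.
  }
  return status;
}

}  // namespace serving
}  // namespace tensorflow
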
diff --git a/tensorflow/contrib/batching/test_util/BUILD b/tensorflow/contrib/batching/test_util/BUILD
index d1ced0d8c3..6db627faad 100644
--- a/tensorflow/contrib/batching/test_util/BUILD
+++ b/tensorflow/contrib/batching/test_util/BUILD
@@ -22,11 +22,9 @@ filegroup(
cc_library(
name = "fake_clock_env",
testonly = 1,
- srcs = ["fake_clock_env.cc"],
hdrs = ["fake_clock_env.h"],
visibility = ["//visibility:public"],
deps = [
- "//tensorflow/core:lib",
- "//tensorflow/core:tensorflow",
+ "//tensorflow/core/kernels/batching_util:fake_clock_env",
],
)
diff --git a/tensorflow/contrib/batching/test_util/fake_clock_env.cc b/tensorflow/contrib/batching/test_util/fake_clock_env.cc
deleted file mode 100644
index 166d6703bd..0000000000
--- a/tensorflow/contrib/batching/test_util/fake_clock_env.cc
+++ /dev/null
@@ -1,90 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/contrib/batching/test_util/fake_clock_env.h"
-
-#include <string>
-
-namespace tensorflow {
-namespace serving {
-namespace test_util {
-
-FakeClockEnv::FakeClockEnv(Env* wrapped) : EnvWrapper(wrapped) {}
-
-void FakeClockEnv::AdvanceByMicroseconds(int micros) {
- {
- mutex_lock l(mu_);
- current_time_ += micros;
- for (auto it = sleeping_threads_.begin(); it != sleeping_threads_.end();) {
- if (current_time_ >= it->wake_time) {
- it->wake_notification->Notify();
- it = sleeping_threads_.erase(it);
- } else {
- ++it;
- }
- }
- }
-}
-
-void FakeClockEnv::BlockUntilSleepingThread(uint64 wake_time) {
- for (;;) {
- {
- mutex_lock l(mu_);
- for (auto it = sleeping_threads_.begin(); it != sleeping_threads_.end();
- ++it) {
- if (it->wake_time == wake_time) {
- return;
- }
- }
- }
- EnvWrapper::SleepForMicroseconds(100);
- }
-}
-
-void FakeClockEnv::BlockUntilThreadsAsleep(int num_threads) {
- for (;;) {
- {
- mutex_lock l(mu_);
- if (num_threads <= sleeping_threads_.size()) {
- return;
- }
- }
- EnvWrapper::SleepForMicroseconds(100);
- }
-}
-
-uint64 FakeClockEnv::NowMicros() {
- {
- mutex_lock l(mu_);
- return current_time_;
- }
-}
-
-void FakeClockEnv::SleepForMicroseconds(int64 micros) {
- if (micros == 0) {
- return;
- }
-
- Notification wake_notification;
- {
- mutex_lock l(mu_);
- sleeping_threads_.push_back({current_time_ + micros, &wake_notification});
- }
- wake_notification.WaitForNotification();
-}
-
-} // namespace test_util
-} // namespace serving
-} // namespace tensorflow
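
The implementation above makes the fake-clock contract concrete: SleepForMicroseconds() parks the calling thread on a Notification, and AdvanceByMicroseconds() wakes every sleeper whose wake time has been reached. A minimal sketch of driving one sleeper from a test; the thread name and durations are illustrative.

#include <memory>

#include "tensorflow/core/kernels/batching_util/fake_clock_env.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {
namespace serving {

void FakeClockSketch() {
  test_util::FakeClockEnv env(Env::Default());
  std::unique_ptr<Thread> sleeper(Env::Default()->StartThread(
      {}, "Sleeper", [&env] {
        env.SleepForMicroseconds(100);  // Parks on a Notification.
      }));
  env.BlockUntilThreadsAsleep(1);  // Wait until the sleeper has registered.
  env.AdvanceByMicroseconds(100);  // Wakes it; NowMicros() now returns 100.
  sleeper.reset();                 // Joins the thread.
}

}  // namespace serving
}  // namespace tensorflow
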
diff --git a/tensorflow/contrib/batching/test_util/fake_clock_env.h b/tensorflow/contrib/batching/test_util/fake_clock_env.h
index 35cafcb73c..ced27a8833 100644
--- a/tensorflow/contrib/batching/test_util/fake_clock_env.h
+++ b/tensorflow/contrib/batching/test_util/fake_clock_env.h
@@ -16,61 +16,6 @@ limitations under the License.
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_
-#include <functional>
-#include <string>
-#include <vector>
-
-#include "tensorflow/core/lib/core/notification.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/thread_annotations.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-namespace serving {
-namespace test_util {
-
-// An Env implementation with a fake clock for NowMicros() and
-// SleepForMicroseconds(). The clock doesn't advance on its own; it advances
-// only via an explicit AdvanceByMicroseconds() call.
-// All other Env virtual methods pass through to a wrapped Env.
-class FakeClockEnv : public EnvWrapper {
- public:
- explicit FakeClockEnv(Env* wrapped);
- ~FakeClockEnv() override = default;
-
- // Advance the clock by a certain number of microseconds.
- void AdvanceByMicroseconds(int micros);
-
- // Blocks until there is a sleeping thread that is scheduled to wake up at
- // the given (absolute) time.
- void BlockUntilSleepingThread(uint64 wake_time);
-
- // Blocks until there are at least num_threads sleeping.
- void BlockUntilThreadsAsleep(int num_threads);
-
- // Methods that this class implements.
- uint64 NowMicros() override;
- void SleepForMicroseconds(int64 micros) override;
-
- private:
- mutex mu_;
-
- uint64 current_time_ GUARDED_BY(mu_) = 0;
-
- struct SleepingThread {
- uint64 wake_time;
- Notification* wake_notification;
- };
- std::vector<SleepingThread> sleeping_threads_ GUARDED_BY(mu_);
-
- TF_DISALLOW_COPY_AND_ASSIGN(FakeClockEnv);
-};
-
-} // namespace test_util
-} // namespace serving
-} // namespace tensorflow
+#include "tensorflow/core/kernels/batching_util/fake_clock_env.h"
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_
diff --git a/tensorflow/contrib/batching/util/BUILD b/tensorflow/contrib/batching/util/BUILD
index f33a08cb81..0df7d456da 100644
--- a/tensorflow/contrib/batching/util/BUILD
+++ b/tensorflow/contrib/batching/util/BUILD
@@ -22,12 +22,10 @@ filegroup(
cc_library(
name = "periodic_function_dynamic",
- srcs = ["periodic_function.cc"],
hdrs = ["periodic_function.h"],
visibility = ["//visibility:public"],
deps = [
- "//tensorflow/core:framework_headers_lib",
- "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/kernels/batching_util:periodic_function_dynamic",
],
)
@@ -36,17 +34,6 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":periodic_function_dynamic",
- "//tensorflow/core:lib",
- ],
-)
-
-tf_cc_test(
- name = "periodic_function_test",
- srcs = ["periodic_function_test.cc"],
- deps = [
- ":periodic_function_dynamic",
- "//tensorflow/contrib/batching/test_util:fake_clock_env",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
+ "//tensorflow/core/kernels/batching_util:periodic_function",
],
)
diff --git a/tensorflow/contrib/batching/util/periodic_function.cc b/tensorflow/contrib/batching/util/periodic_function.cc
deleted file mode 100644
index b7e4838da5..0000000000
--- a/tensorflow/contrib/batching/util/periodic_function.cc
+++ /dev/null
@@ -1,102 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/contrib/batching/util/periodic_function.h"
-
-#include <algorithm>
-
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/platform/logging.h"
-
-namespace tensorflow {
-namespace serving {
-
-PeriodicFunction::PeriodicFunction(const std::function<void()>& function,
- const int64 interval_micros,
- const Options& options)
- : function_(function),
- interval_micros_([interval_micros]() -> int64 {
- if (interval_micros < 0) {
- const string error = strings::StrCat(
- " The value of 'interval_micros' should be >= 0: ",
- interval_micros, ". ");
- DCHECK(false) << error;
- LOG(WARNING) << error << "Resetting it to 0.";
- return 0;
- }
- return interval_micros;
- }()),
- options_(options) {
- thread_.reset(options_.env->StartThread(
- options_.thread_options, options_.thread_name_prefix, [this]() {
- // Record the starting time here instead of in RunLoop. That way, if
- // there is a delay starting RunLoop, that does not affect the timing of
- // the first function. (Such a delay can often happen in tests where the
- // test simulates a large time delay immediately after calling Start.)
- RunLoop(options_.env->NowMicros());
- }));
-}
-
-PeriodicFunction::~PeriodicFunction() {
- NotifyStop();
-
- // Waits for thread_ to complete and clean up.
- thread_.reset();
-}
-
-void PeriodicFunction::NotifyStop() {
- if (!stop_thread_.HasBeenNotified()) {
- stop_thread_.Notify();
- }
-}
-
-void PeriodicFunction::RunLoop(const int64 start) {
- {
- if (options_.startup_delay_micros > 0) {
- const int64 deadline = start + options_.startup_delay_micros;
- options_.env->SleepForMicroseconds(deadline - start);
- }
-
- while (!stop_thread_.HasBeenNotified()) {
- VLOG(3) << "Running function.";
- const int64 begin = options_.env->NowMicros();
- function_();
-
- // Take the max() here to guard against time going backwards which
- // sometimes happens in multiproc machines.
- const int64 end =
- std::max(static_cast<int64>(options_.env->NowMicros()), begin);
-
- // The deadline is relative to when the last function started.
- const int64 deadline = begin + interval_micros_;
-
- // We want to sleep until 'deadline'.
- if (deadline > end) {
- if (end > begin) {
- VLOG(3) << "Reducing interval_micros from " << interval_micros_
- << " to " << (deadline - end);
- }
- options_.env->SleepForMicroseconds(deadline - end);
- } else {
- VLOG(3) << "Function took longer than interval_micros, so not sleeping";
- }
- }
- }
-}
-
-} // namespace serving
-} // namespace tensorflow
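
To make the RunLoop() arithmetic above concrete: with interval_micros = 50, a run that begins at t = 100 and ends at t = 130 has deadline = begin + interval_micros = 150, so the thread sleeps for deadline - end = 20 microseconds. If the same run instead ends at t = 170, the deadline has already passed, so there is no sleep and the next run starts immediately at t = 170; the slot at t = 150 is skipped entirely, which is the "skipped call" behavior documented in the header.
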
diff --git a/tensorflow/contrib/batching/util/periodic_function.h b/tensorflow/contrib/batching/util/periodic_function.h
index 2c032d802f..fb61bc2eea 100644
--- a/tensorflow/contrib/batching/util/periodic_function.h
+++ b/tensorflow/contrib/batching/util/periodic_function.h
@@ -12,121 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
-// PeriodicFunction will periodically call the given function with a specified
-// period in a background thread. After Start() returns, the thread is
-// guaranteed to have started. The destruction of the class causes the
-// background thread to be destroyed as well. Start() should not be called more
-// than once.
-//
-// PeriodicFunction runs the function as soon as the previous run is both
-// complete and was started more than "interval_micros" earlier. Runs are
-// therefore serialized, and normally occur with a period of "interval_micros"
-// as long as no run takes longer than that.
-//
-// Note that, if the function takes longer than two periods (2 *
-// "interval_micros") to finish, then PeriodicFunction will "skip" at least one
-// call to the function. For instance, if the period is 50ms and a run starts
-// at time 0 and lasts 150ms, the function will immediately start executing
-// again at time 150, but there will be no runs corresponding to times 50 or
-// 100. This is especially important to remember when using an environment with
-// a simulated clock: advancing simulated time atomically over N periods will
-// not cause the function to be called N times.
-//
-// This object is thread-safe.
-//
-// Example:
-//
-// class Foo {
-// public:
-// Foo() : periodic_function_([this]() { Bar(); },
-// 1000 /* 1000us == 1ms*/) {
-// }
-//
-// private:
-// void Bar() { ... }
-//
-// PeriodicFunction periodic_function_;
-// };
-
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_
-#include <functional>
-#include <memory>
-#include <string>
-
-#include "tensorflow/core/lib/core/notification.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/thread_annotations.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-namespace serving {
-
-namespace internal {
-class PeriodicFunctionTestAccess;
-}
-
-class PeriodicFunction {
- public:
- // Provides the ability to customize several aspects of the PeriodicFunction.
- // Passed to constructor of PeriodicFunction.
- struct Options {
- Options() {}
-
- // Any standard thread options, such as stack size, should
- // be passed via "thread_options".
- ThreadOptions thread_options;
-
- // Specifies the thread name prefix (see the description in class
- // Thread).
- string thread_name_prefix = "periodic_function";
-
- // The environment to use. Does not take ownership, but must remain alive
- // for as long as the PeriodicFunction exists.
- Env* env = Env::Default();
-
- // Specifies the length of sleep before the first invocation of the
- // function. This can be used for adding a random jitter to avoid
- // synchronous behavior across multiple periodic functions.
- int64 startup_delay_micros = 0;
- };
-
- // Also starts the background thread which will be calling the function.
- PeriodicFunction(const std::function<void()>& function, int64 interval_micros,
- const Options& options = Options());
-
- ~PeriodicFunction();
-
- private:
- friend class internal::PeriodicFunctionTestAccess;
-
- // Notifies the background thread to stop.
- void NotifyStop();
-
- // (Blocking.) Loops forever calling "function_" every "interval_micros_".
- void RunLoop(int64 start) LOCKS_EXCLUDED(mutex_);
-
- const std::function<void()> function_; // Actual client function
- const int64 interval_micros_; // Interval between calls.
- const Options options_;
-
- // Protects state below.
- mutable mutex mutex_;
- // Used to notify the thread to stop.
- Notification stop_thread_;
-
- // Thread for running "function_"
- std::unique_ptr<Thread> thread_ GUARDED_BY(mutex_) = nullptr;
-
- TF_DISALLOW_COPY_AND_ASSIGN(PeriodicFunction);
-};
-
-} // namespace serving
-} // namespace tensorflow
+#include "tensorflow/core/kernels/batching_util/periodic_function.h"
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_
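
The header above (now forwarded to tensorflow/core/kernels/batching_util) documents the PeriodicFunction contract and its Options. A slightly expanded usage sketch follows; the Flusher class, the 10 ms period, and the 1 ms startup delay are made up for illustration.

#include <memory>

#include "tensorflow/core/kernels/batching_util/periodic_function.h"

class Flusher {
 public:
  Flusher() {
    tensorflow::serving::PeriodicFunction::Options options;
    options.thread_name_prefix = "flusher";
    options.startup_delay_micros = 1000;  // Optional jitter before the first run.
    periodic_flush_.reset(new tensorflow::serving::PeriodicFunction(
        [this] { Flush(); }, 10 * 1000 /* 10ms */, options));
  }
  // Destroying 'periodic_flush_' notifies, stops, and joins the background
  // thread.

 private:
  void Flush() { /* ... */ }

  std::unique_ptr<tensorflow::serving::PeriodicFunction> periodic_flush_;
};
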
diff --git a/tensorflow/contrib/batching/util/periodic_function_test.cc b/tensorflow/contrib/batching/util/periodic_function_test.cc
deleted file mode 100644
index 1517961116..0000000000
--- a/tensorflow/contrib/batching/util/periodic_function_test.cc
+++ /dev/null
@@ -1,225 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/contrib/batching/util/periodic_function.h"
-
-#include <memory>
-#include <string>
-
-#include "tensorflow/contrib/batching/test_util/fake_clock_env.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace serving {
-
-namespace internal {
-
-class PeriodicFunctionTestAccess {
- public:
- explicit PeriodicFunctionTestAccess(PeriodicFunction* periodic_function)
- : periodic_function_(periodic_function) {}
-
- void NotifyStop() { periodic_function_->NotifyStop(); }
-
- private:
- PeriodicFunction* const periodic_function_;
-};
-
-} // namespace internal
-
-namespace {
-
-using test_util::FakeClockEnv;
-
-void StopPeriodicFunction(PeriodicFunction* periodic_function,
- FakeClockEnv* fake_clock_env,
- const uint64 pf_interval_micros) {
- fake_clock_env->BlockUntilThreadsAsleep(1);
- internal::PeriodicFunctionTestAccess(periodic_function).NotifyStop();
- fake_clock_env->AdvanceByMicroseconds(pf_interval_micros);
-}
-
-TEST(PeriodicFunctionTest, ObeyInterval) {
- const int64 kPeriodMicros = 2;
- const int kCalls = 10;
-
- int actual_calls = 0;
- {
- FakeClockEnv fake_clock_env(Env::Default());
- PeriodicFunction::Options options;
- options.env = &fake_clock_env;
- PeriodicFunction periodic_function([&actual_calls]() { ++actual_calls; },
- kPeriodMicros, options);
-
- for (int i = 0; i < kCalls; ++i) {
- fake_clock_env.BlockUntilThreadsAsleep(1);
- fake_clock_env.AdvanceByMicroseconds(kPeriodMicros);
- }
- StopPeriodicFunction(&periodic_function, &fake_clock_env, kPeriodMicros);
- }
-
- // The function gets called kCalls+1 times: once at time 0, once at time
- // kPeriodMicros, once at time kPeriodMicros*2, up to once at time
- // kPeriodMicros*kCalls.
- ASSERT_EQ(actual_calls, kCalls + 1);
-}
-
-TEST(PeriodicFunctionTest, ObeyStartupDelay) {
- const int64 kDelayMicros = 10;
- const int64 kPeriodMicros = kDelayMicros / 10;
-
- int actual_calls = 0;
- {
- PeriodicFunction::Options options;
- options.startup_delay_micros = kDelayMicros;
- FakeClockEnv fake_clock_env(Env::Default());
- options.env = &fake_clock_env;
- PeriodicFunction periodic_function([&actual_calls]() { ++actual_calls; },
- kPeriodMicros, options);
-
- // Wait for the thread to start up.
- fake_clock_env.BlockUntilThreadsAsleep(1);
- // Function shouldn't have been called yet.
- EXPECT_EQ(0, actual_calls);
- // Give enough time for startup delay to expire.
- fake_clock_env.AdvanceByMicroseconds(kDelayMicros);
- StopPeriodicFunction(&periodic_function, &fake_clock_env, kDelayMicros);
- }
-
- // Function should have been called at least once.
- EXPECT_EQ(1, actual_calls);
-}
-
-// Test for race in calculating the first time the callback should fire.
-TEST(PeriodicFunctionTest, StartupDelayRace) {
- const int64 kDelayMicros = 10;
- const int64 kPeriodMicros = kDelayMicros / 10;
-
- mutex mu;
- int counter = 0;
- std::unique_ptr<Notification> listener(new Notification);
-
- FakeClockEnv fake_clock_env(Env::Default());
- PeriodicFunction::Options options;
- options.env = &fake_clock_env;
- options.startup_delay_micros = kDelayMicros;
- PeriodicFunction periodic_function(
- [&mu, &counter, &listener]() {
- mutex_lock l(mu);
- counter++;
- listener->Notify();
- },
- kPeriodMicros, options);
-
- fake_clock_env.BlockUntilThreadsAsleep(1);
- fake_clock_env.AdvanceByMicroseconds(kDelayMicros);
- listener->WaitForNotification();
- {
- mutex_lock l(mu);
- EXPECT_EQ(1, counter);
- // A notification can only be notified once.
- listener.reset(new Notification);
- }
- fake_clock_env.BlockUntilThreadsAsleep(1);
- fake_clock_env.AdvanceByMicroseconds(kPeriodMicros);
- listener->WaitForNotification();
- {
- mutex_lock l(mu);
- EXPECT_EQ(2, counter);
- }
- StopPeriodicFunction(&periodic_function, &fake_clock_env, kPeriodMicros);
-}
-
-// If this test hangs forever, it's probably a deadlock caused by setting the
-// PeriodicFunction's interval to 0ms.
-TEST(PeriodicFunctionTest, MinInterval) {
- PeriodicFunction periodic_function(
- []() { Env::Default()->SleepForMicroseconds(20 * 1000); }, 0);
-}
-
-class PeriodicFunctionWithFakeClockEnvTest : public ::testing::Test {
- protected:
- const int64 kPeriodMicros = 50;
- PeriodicFunctionWithFakeClockEnvTest()
- : fake_clock_env_(Env::Default()),
- counter_(0),
- pf_(
- [this]() {
- mutex_lock l(counter_mu_);
- ++counter_;
- },
- kPeriodMicros, GetPeriodicFunctionOptions()) {}
-
- PeriodicFunction::Options GetPeriodicFunctionOptions() {
- PeriodicFunction::Options options;
- options.thread_name_prefix = "ignore";
- options.env = &fake_clock_env_;
- return options;
- }
-
- void SetUp() override {
- // Note: counter_ gets initially incremented at time 0.
- ASSERT_TRUE(AwaitCount(1));
- }
-
- void TearDown() override {
- StopPeriodicFunction(&pf_, &fake_clock_env_, kPeriodMicros);
- }
-
- // The FakeClockEnv tests below advance simulated time and then expect the
- // PeriodicFunction thread to run its function. This method helps the tests
- // wait for the thread to execute, and then check the count matches the
- // expectation.
- bool AwaitCount(int expected_counter) {
- fake_clock_env_.BlockUntilThreadsAsleep(1);
- {
- mutex_lock lock(counter_mu_);
- return counter_ == expected_counter;
- }
- }
-
- FakeClockEnv fake_clock_env_;
- mutex counter_mu_;
- int counter_;
- PeriodicFunction pf_;
-};
-
-TEST_F(PeriodicFunctionWithFakeClockEnvTest, FasterThanRealTime) {
- fake_clock_env_.AdvanceByMicroseconds(kPeriodMicros / 2);
- for (int i = 2; i < 7; ++i) {
- fake_clock_env_.AdvanceByMicroseconds(
- kPeriodMicros); // advance past a tick
- EXPECT_TRUE(AwaitCount(i));
- }
-}
-
-TEST_F(PeriodicFunctionWithFakeClockEnvTest, SlowerThanRealTime) {
- Env::Default()->SleepForMicroseconds(
- 125 * 1000); // wait for any unexpected breakage
- EXPECT_TRUE(AwaitCount(1));
-}
-
-TEST(PeriodicFunctionDeathTest, BadInterval) {
- EXPECT_DEBUG_DEATH(PeriodicFunction periodic_function([]() {}, -1),
- ".* should be >= 0");
-
- EXPECT_DEBUG_DEATH(PeriodicFunction periodic_function(
- []() {}, -1, PeriodicFunction::Options()),
- ".* should be >= 0");
-}
-
-} // namespace
-} // namespace serving
-} // namespace tensorflow
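
Tying the tests above back to the header comment: because the next deadline is computed from when the previous run began, jumping the fake clock over several periods at once yields a single extra run, not one per period. A compact sketch, assuming it sits in the same anonymous namespace as the tests above, with their includes, the FakeClockEnv using-declaration, and the StopPeriodicFunction helper available:

TEST(PeriodicFunctionTest, SkipsMissedTicksSketch) {
  int calls = 0;
  FakeClockEnv fake_clock_env(Env::Default());
  PeriodicFunction::Options options;
  options.env = &fake_clock_env;
  PeriodicFunction pf([&calls] { ++calls; }, /* interval_micros = */ 50,
                      options);

  fake_clock_env.BlockUntilThreadsAsleep(1);     // First run happened at t = 0.
  fake_clock_env.AdvanceByMicroseconds(5 * 50);  // Jump five periods at once.
  fake_clock_env.BlockUntilThreadsAsleep(1);     // Exactly one more run occurs.
  StopPeriodicFunction(&pf, &fake_clock_env, 50);

  EXPECT_EQ(2, calls);  // Not 6: the intermediate ticks are skipped.
}
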