From c9bdd3938e2b43334a0065b4c198ec9d491c8cb8 Mon Sep 17 00:00:00 2001
From: Derek Murray
Date: Wed, 3 Oct 2018 10:04:37 -0700
Subject: [tf.data] Switch background threads to use `BackgroundWorker`.

PiperOrigin-RevId: 215579950
---
 tensorflow/core/kernels/data/iterator_ops.cc        |  4 ----
 .../core/kernels/data/map_and_batch_dataset_op.cc   | 10 ++++----
 tensorflow/core/kernels/data/model_dataset_op.cc    | 10 ++++----
 .../kernels/data/parallel_interleave_dataset_op.cc  | 27 +++++++++++++---------
 .../core/kernels/data/parallel_map_iterator.cc      | 10 ++++----
 tensorflow/core/kernels/data/prefetch_dataset_op.cc | 10 ++++----
 tensorflow/core/kernels/data/writer_ops.cc          | 12 +++++-----
 7 files changed, 46 insertions(+), 37 deletions(-)
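
Every call site below follows the same before/after shape: a raw `Env::StartThread()` call, whose returned `Thread*` joins in its destructor, is replaced by a `BackgroundWorker` that owns its thread and accepts work items through `Schedule()`. A minimal sketch of the pattern, assuming only the interface implied by the call sites in this patch (a `BackgroundWorker(Env*, const string&)` constructor and a `Schedule(std::function<void()>)` method):

    // Before: the iterator owns a raw thread.
    //   std::unique_ptr<Thread> runner_thread_;
    runner_thread_.reset(ctx->env()->StartThread(
        {}, "runner_thread", [this]() { RunnerThread(); }));

    // After: the iterator owns a worker and hands it one work item.
    //   std::unique_ptr<BackgroundWorker> runner_thread_;
    runner_thread_ =
        MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread");
    runner_thread_->Schedule([this]() { RunnerThread(); });

The lambda here is simplified; the real call sites capture a `std::shared_ptr<IteratorContext>` as shown in the hunks below.
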
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 7a833668ac..8acd6cc724 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -16,10 +16,8 @@ limitations under the License.
 
 #include "tensorflow/core/common_runtime/graph_runner.h"
 #include "tensorflow/core/common_runtime/renamed_device.h"
-#include "tensorflow/core/common_runtime/threadpool_device.h"
 #include "tensorflow/core/framework/iterator.pb.h"
 #include "tensorflow/core/framework/partial_tensor_shape.h"
-#include "tensorflow/core/framework/resource_op_kernel.h"
 #include "tensorflow/core/framework/stats_aggregator.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/variant_op_registry.h"
@@ -27,13 +25,11 @@ limitations under the License.
 #include "tensorflow/core/kernels/data/dataset_utils.h"
 #include "tensorflow/core/kernels/data/optional_ops.h"
 #include "tensorflow/core/kernels/ops_util.h"
-#include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/public/session_options.h"
 
 namespace tensorflow {
 namespace data {
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index bf08970560..6a670f1efb 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -29,6 +29,7 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/cpu_info.h"
 #include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/util/ptr_util.h"
 
 namespace tensorflow {
 namespace data {
@@ -405,9 +406,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
         EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       if (!runner_thread_) {
         std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
-        runner_thread_.reset(ctx->env()->StartThread(
-            {}, "runner_thread",
-            std::bind(&Iterator::RunnerThread, this, ctx_copy)));
+        runner_thread_ =
+            MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread");
+        runner_thread_->Schedule(
+            std::bind(&Iterator::RunnerThread, this, ctx_copy));
       }
     }
 
@@ -660,7 +662,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
     std::unique_ptr<IteratorBase> input_impl_;
     // Buffer for storing the (intermediate) batch results.
     std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(*mu_);
-    std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+    std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_);
     bool cancelled_ GUARDED_BY(*mu_) = false;
   };
 
diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc
index 9aa505f4f1..859df57962 100644
--- a/tensorflow/core/kernels/data/model_dataset_op.cc
+++ b/tensorflow/core/kernels/data/model_dataset_op.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/data/dataset.h"
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/util/ptr_util.h"
 
 namespace tensorflow {
 namespace data {
@@ -126,9 +127,10 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
         EXCLUSIVE_LOCKS_REQUIRED(mu_) {
       if (!optimize_thread_) {
         std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
-        optimize_thread_.reset(ctx->env()->StartThread(
-            {}, "optimize_thread",
-            [this, new_ctx]() { OptimizeThread(new_ctx); }));
+        optimize_thread_ =
+            MakeUnique<BackgroundWorker>(ctx->env(), "optimize_thread");
+        optimize_thread_->Schedule(
+            [this, new_ctx]() { OptimizeThread(new_ctx); });
       }
       return Status::OK();
     }
@@ -167,7 +169,7 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
     mutex mu_;
     condition_variable cond_var_;
     std::shared_ptr<model::Model> model_;
-    std::unique_ptr<Thread> optimize_thread_ GUARDED_BY(mu_);
+    std::unique_ptr<BackgroundWorker> optimize_thread_ GUARDED_BY(mu_);
     bool cancelled_ GUARDED_BY(mu_) = false;
     std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
   };
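
The kernels above share the same lazily-started-worker shape: the worker is created under the iterator's lock on first use, and the long-running loop is handed over as a single scheduled work item. A condensed sketch of that shape (member and method names follow the hunks above; this is illustrative, not a verbatim excerpt):

    void EnsureRunnerThreadStarted(IteratorContext* ctx)
        EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (!runner_thread_) {
        // Copy the context into a shared_ptr so that the background loop
        // can outlive the caller's stack frame.
        std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
        runner_thread_ =
            MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread");
        runner_thread_->Schedule(
            std::bind(&Iterator::RunnerThread, this, ctx_copy));
      }
    }
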
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 6b6b3d6ab9..9c836b836e 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/util/ptr_util.h"
 
 namespace tensorflow {
 namespace data {
@@ -481,9 +482,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
           worker_threads_.reserve(dataset()->num_threads());
           for (size_t i = 0; i < dataset()->num_threads(); ++i) {
             std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
-            worker_threads_.emplace_back(ctx->env()->StartThread(
-                {}, "worker_thread",
-                [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
+            worker_threads_.emplace_back(
+                MakeUnique<BackgroundWorker>(ctx->env(), "worker_thread"));
+            worker_threads_.back()->Schedule(
+                [this, new_ctx, i]() { WorkerThread(new_ctx, i); });
           }
         }
         return Status::OK();
@@ -580,9 +582,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
             }
             workers_[i].SetInputs(s, std::move(args));
             std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
-            worker_threads_.emplace_back(ctx->env()->StartThread(
-                {}, "worker_thread",
-                [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
+            worker_threads_.emplace_back(
+                MakeUnique<BackgroundWorker>(ctx->env(), "worker_thread"));
+            worker_threads_.back()->Schedule(
+                [this, new_ctx, i]() { WorkerThread(new_ctx, i); });
             if (i < dataset()->cycle_length_) {
               interleave_indices_.push_back(i);
             } else {
@@ -1047,7 +1050,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
     // The worker threads. This must be last to ensure the
     // threads have exited before any other members are deallocated.
     // TODO(b/65178177): Avoid allocating additional threads.
-    std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_);
+    std::vector<std::unique_ptr<BackgroundWorker>> worker_threads_
+        GUARDED_BY(mu_);
   };
 
   const DatasetBase* const input_;
@@ -1389,9 +1393,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
         EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       if (!runner_thread_) {
         std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
-        runner_thread_.reset(ctx->env()->StartThread(
-            {}, "runner_thread",
-            [this, new_ctx]() { RunnerThread(new_ctx); }));
+        runner_thread_ =
+            MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread");
+        runner_thread_->Schedule(
+            [this, new_ctx]() { RunnerThread(new_ctx); });
       }
     }
 
@@ -1645,7 +1650,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
     int64 num_calls_ GUARDED_BY(*mu_) = 0;
 
     std::unique_ptr<thread::ThreadPool> thread_pool_;
-    std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+    std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_);
 
     // Identifies whether background activity should be cancelled.
     bool cancelled_ GUARDED_BY(*mu_) = false;
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index 13bd4b6036..626e98af91 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -22,6 +22,7 @@
 
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/util/ptr_util.h"
 
 namespace tensorflow {
 namespace data {
@@ -180,9 +181,10 @@ class ParallelMapIterator : public DatasetBaseIterator {
         EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       if (!runner_thread_) {
         std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
-        runner_thread_.reset(ctx->env()->StartThread(
-            {}, "runner_thread",
-            std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy)));
+        runner_thread_ =
+            MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread");
+        runner_thread_->Schedule(
+            std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy));
       }
     }
 
@@ -330,7 +332,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
     // Buffer for storing the invocation results.
     std::deque<std::shared_ptr<InvocationResult>> invocation_results_
         GUARDED_BY(*mu_);
-    std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+    std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_);
     bool cancelled_ GUARDED_BY(*mu_) = false;
   };
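
Two details in the interleave hunks are worth calling out: each `BackgroundWorker` still corresponds to exactly one thread, so the V1 kernel keeps a vector of workers, one per `WorkerThread()` loop, and that member stays last in the class so the threads are joined before any other member is destroyed. A sketch of that shape (names as in the hunks above; illustrative only):

    // Declared last: destruction joins each worker's thread before the
    // members the loops touch are torn down.
    std::vector<std::unique_ptr<BackgroundWorker>> worker_threads_
        GUARDED_BY(mu_);

    // Starting one long-running loop per worker:
    for (size_t i = 0; i < dataset()->num_threads(); ++i) {
      std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
      worker_threads_.emplace_back(
          MakeUnique<BackgroundWorker>(ctx->env(), "worker_thread"));
      worker_threads_.back()->Schedule(
          [this, new_ctx, i]() { WorkerThread(new_ctx, i); });
    }
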
#include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -256,10 +257,11 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { Status EnsurePrefetchThreadStarted(IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (!prefetch_thread_) { + prefetch_thread_ = + MakeUnique(ctx->env(), "prefetch_thread"); std::shared_ptr new_ctx(new IteratorContext(*ctx)); - prefetch_thread_.reset(ctx->env()->StartThread( - {}, "prefetch_thread", - [this, new_ctx]() { PrefetchThread(new_ctx); })); + prefetch_thread_->Schedule( + [this, new_ctx]() { PrefetchThread(new_ctx); }); } return Status::OK(); } @@ -363,7 +365,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { string prefix_end_; PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_); std::deque buffer_ GUARDED_BY(mu_); - std::unique_ptr prefetch_thread_ GUARDED_BY(mu_); + std::unique_ptr prefetch_thread_ GUARDED_BY(mu_); bool cancelled_ GUARDED_BY(mu_) = false; bool prefetch_thread_finished_ GUARDED_BY(mu_) = false; }; diff --git a/tensorflow/core/kernels/data/writer_ops.cc b/tensorflow/core/kernels/data/writer_ops.cc index 3f76695bb1..7bb2077b62 100644 --- a/tensorflow/core/kernels/data/writer_ops.cc +++ b/tensorflow/core/kernels/data/writer_ops.cc @@ -29,10 +29,10 @@ class ToTFRecordOp : public AsyncOpKernel { public: explicit ToTFRecordOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx), - thread_pool_(new thread::ThreadPool( - ctx->env(), ThreadOptions(), - strings::StrCat("to_tf_record__op_", SanitizeThreadSuffix(name())), - 1 /* num_threads */, false /* low_latency_hint */)) {} + background_worker_( + ctx->env(), + strings::StrCat("to_tf_record_op_", SanitizeThreadSuffix(name()))) { + } template Status ParseScalarArgument(OpKernelContext* ctx, @@ -50,7 +50,7 @@ class ToTFRecordOp : public AsyncOpKernel { // The call to `iterator->GetNext()` may block and depend on an // inter-op thread pool thread, so we issue the call from the // owned thread pool. - thread_pool_->Schedule([this, ctx, done]() { + background_worker_.Schedule([this, ctx, done]() { string filename; OP_REQUIRES_OK_ASYNC( ctx, ParseScalarArgument(ctx, "filename", &filename), done); @@ -97,7 +97,7 @@ class ToTFRecordOp : public AsyncOpKernel { } private: - std::unique_ptr thread_pool_; + BackgroundWorker background_worker_; }; REGISTER_KERNEL_BUILDER(Name("DatasetToTFRecord").Device(DEVICE_CPU), -- cgit v1.2.3