diff options
Diffstat (limited to 'tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc | 27 |
1 files changed, 16 insertions, 11 deletions
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index 6b6b3d6ab9..9c836b836e 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -481,9 +482,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { worker_threads_.reserve(dataset()->num_threads()); for (size_t i = 0; i < dataset()->num_threads(); ++i) { std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); - worker_threads_.emplace_back(ctx->env()->StartThread( - {}, "worker_thread", - [this, new_ctx, i]() { WorkerThread(new_ctx, i); })); + worker_threads_.emplace_back( + MakeUnique<BackgroundWorker>(ctx->env(), "worker_thread")); + worker_threads_.back()->Schedule( + [this, new_ctx, i]() { WorkerThread(new_ctx, i); }); } } return Status::OK(); @@ -580,9 +582,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { } workers_[i].SetInputs(s, std::move(args)); std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); - worker_threads_.emplace_back(ctx->env()->StartThread( - {}, "worker_thread", - [this, new_ctx, i]() { WorkerThread(new_ctx, i); })); + worker_threads_.emplace_back( + MakeUnique<BackgroundWorker>(ctx->env(), "worker_thread")); + worker_threads_.back()->Schedule( + [this, new_ctx, i]() { WorkerThread(new_ctx, i); }); if (i < dataset()->cycle_length_) { interleave_indices_.push_back(i); } else { @@ -1047,7 +1050,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { // The worker threads. This must be last to ensure the // threads have exited before any other members are deallocated. // TODO(b/65178177): Avoid allocating additional threads. - std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_); + std::vector<std::unique_ptr<BackgroundWorker>> worker_threads_ + GUARDED_BY(mu_); }; const DatasetBase* const input_; @@ -1389,9 +1393,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!runner_thread_) { std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); - runner_thread_.reset(ctx->env()->StartThread( - {}, "runner_thread", - [this, new_ctx]() { RunnerThread(new_ctx); })); + runner_thread_ = + MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread"); + runner_thread_->Schedule( + [this, new_ctx]() { RunnerThread(new_ctx); }); } } @@ -1645,7 +1650,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { int64 num_calls_ GUARDED_BY(*mu_) = 0; std::unique_ptr<thread::ThreadPool> thread_pool_; - std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_); + std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_); // Identifies whether background activity should be cancelled. bool cancelled_ GUARDED_BY(*mu_) = false; |