diff options
Diffstat (limited to 'tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc | 38 |
1 file changed, 29 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index 640f1565b7..aa5e613e24 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -252,6 +252,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { } Status Initialize(IteratorContext* ctx) override { + SetMetadata(ctx, "parallelism", dataset()->cycle_length_); TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); return dataset()->captured_func_->Instantiate(ctx); @@ -351,11 +352,13 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { if (must_wait_for_input) { // Wait for elements to become available. + StopWork(ctx); if (dataset()->sloppy_) { sloppy_cond_var_.wait(l); } else { workers_[interleave_indices_[next_index_]].cond_var.wait(l); } + StartWork(ctx); } } return errors::Cancelled( @@ -484,10 +487,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { if (reader->Contains(full_name("worker_threads_running"))) { worker_threads_.reserve(dataset()->num_threads()); for (size_t i = 0; i < dataset()->num_threads(); ++i) { + std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); worker_threads_.emplace_back(ctx->env()->StartThread( {}, "worker_thread", - std::bind(&Iterator::WorkerThread, this, - new IteratorContext(*ctx), i))); + [this, new_ctx, i]() { WorkerThread(new_ctx, i); })); } } return Status::OK(); @@ -583,10 +586,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } workers_[i].SetInputs(s, std::move(args)); + std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); worker_threads_.emplace_back(ctx->env()->StartThread( {}, "worker_thread", - std::bind(&Iterator::WorkerThread, this, - new IteratorContext(*ctx), i))); + [this, new_ctx, i]() { WorkerThread(new_ctx, i); })); if (i < 
dataset()->cycle_length_) { interleave_indices_.push_back(i); } else { @@ -601,7 +604,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { } // Produces elements into the worker's output buffers. - void WorkerThread(IteratorContext* ctx_ptr, const int64 thread_index) { + void WorkerThread(const std::shared_ptr<IteratorContext>& ctx, + const int64 thread_index) { // Notes on checkpointing thread local state, i.e., `WorkerThreadState`: // // 1. Any local state that may need to be checkpointed should be kept @@ -622,10 +626,11 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { // std::function arguments are copy-constructible, so we pass raw // pointers, and then immediately wrap them to ensure correct ownership. - std::unique_ptr<IteratorContext> ctx(ctx_ptr); - auto cleanup = gtl::MakeCleanup([this, thread_index] { + StartWork(ctx.get()); + auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] { mutex_lock l(mu_); workers_[thread_index].cond_var.notify_all(); + StopWork(ctx.get()); }); bool make_new_iterator; { @@ -651,9 +656,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { // 1. Build a new iterator or use the existing one. if (make_new_iterator) { // 1a. Get new input tensors or use the existing ones. - bool read_new_input; - { tf_shared_lock l(ckpt_mu_); // worker_thread_states_[thread_index].input will be non-empty @@ -665,7 +668,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { if (read_new_input) { mutex_lock l(mu_); while (!cancelled_ && !workers_[thread_index].is_producing) { + StopWork(ctx.get()); workers_[thread_index].cond_var.wait(l); + StartWork(ctx.get()); } if (cancelled_) return; // Copy the input tensors so that we do not need to block on `mu_` @@ -715,7 +720,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { // Wait for space in the prefetch queue. 
while (!cancelled_ && workers_[thread_index].outputs.size() == dataset()->buffer_output_elements_) { + StopWork(ctx.get()); workers_[thread_index].cond_var.wait(l); + StartWork(ctx.get()); } if (cancelled_) return; tf_shared_lock ckpt_l(ckpt_mu_); @@ -764,7 +771,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { // Wait for space in the prefetch queue. while (!cancelled_ && workers_[thread_index].outputs.size() == dataset()->buffer_output_elements_) { + StopWork(ctx.get()); workers_[thread_index].cond_var.wait(l); + StartWork(ctx.get()); } if (cancelled_) return; @@ -1241,6 +1250,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { } Status Initialize(IteratorContext* ctx) override { + SetMetadata(ctx, "parallelism", dataset()->num_parallel_calls_); TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); return dataset()->captured_func_->Instantiate(ctx); @@ -1256,7 +1266,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { EnsureRunnerThreadStarted(ctx); while (invocation_results_.empty() && (!end_of_input_ || num_open_ > 0)) { + StopWork(ctx); cond_var_.wait(l); + StartWork(ctx); } if (!invocation_results_.empty()) { std::swap(result, invocation_results_.front()); @@ -1267,7 +1279,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { } } cond_var_.notify_all(); + StopWork(ctx); result->notification.WaitForNotification(); + StartWork(ctx); } while (result->skip); if (result->status.ok()) { @@ -1391,6 +1405,8 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index, const std::vector<std::shared_ptr<InvocationResult>>& results) LOCKS_EXCLUDED(mu_) { + StartWork(ctx.get()); + auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); }); bool end_of_input = false; for (auto& result : results) { if (!end_of_input) { @@ -1433,6 +1449,8 @@ class ParallelInterleaveDatasetV2Op : 
public UnaryDatasetOpKernel { // // This method runs in the `runner_thread` background thread. void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) { + StartWork(ctx.get()); + auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); }); while (true) { { mutex_lock l(mu_); @@ -1443,7 +1461,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { (element_in_use_[cycle_index_] || num_calls_ >= dataset()->num_parallel_calls_ || invocation_results_.size() >= MaxInvocationResults())) { + StopWork(ctx.get()); cond_var_.wait(l); + StartWork(ctx.get()); } if (cancelled_ || (end_of_input_ && num_open_ == 0)) { |