aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc | 38
1 files changed, 29 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 640f1565b7..aa5e613e24 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -252,6 +252,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
+ SetMetadata(ctx, "parallelism", dataset()->cycle_length_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
@@ -351,11 +352,13 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (must_wait_for_input) {
// Wait for elements to become available.
+ StopWork(ctx);
if (dataset()->sloppy_) {
sloppy_cond_var_.wait(l);
} else {
workers_[interleave_indices_[next_index_]].cond_var.wait(l);
}
+ StartWork(ctx);
}
}
return errors::Cancelled(
@@ -484,10 +487,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (reader->Contains(full_name("worker_threads_running"))) {
worker_threads_.reserve(dataset()->num_threads());
for (size_t i = 0; i < dataset()->num_threads(); ++i) {
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back(ctx->env()->StartThread(
{}, "worker_thread",
- std::bind(&Iterator::WorkerThread, this,
- new IteratorContext(*ctx), i)));
+ [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
}
}
return Status::OK();
@@ -583,10 +586,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
workers_[i].SetInputs(s, std::move(args));
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back(ctx->env()->StartThread(
{}, "worker_thread",
- std::bind(&Iterator::WorkerThread, this,
- new IteratorContext(*ctx), i)));
+ [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
if (i < dataset()->cycle_length_) {
interleave_indices_.push_back(i);
} else {
@@ -601,7 +604,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
// Produces elements into the worker's output buffers.
- void WorkerThread(IteratorContext* ctx_ptr, const int64 thread_index) {
+ void WorkerThread(const std::shared_ptr<IteratorContext>& ctx,
+ const int64 thread_index) {
// Notes on checkpointing thread local state, i.e., `WorkerThreadState`:
//
// 1. Any local state that may need to be checkpointed should be kept
@@ -622,10 +626,11 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// std::function arguments are copy-constructable, so we pass raw
// pointers, and then immediately wrap them to ensure correct ownership.
- std::unique_ptr<IteratorContext> ctx(ctx_ptr);
- auto cleanup = gtl::MakeCleanup([this, thread_index] {
+ StartWork(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] {
mutex_lock l(mu_);
workers_[thread_index].cond_var.notify_all();
+ StopWork(ctx.get());
});
bool make_new_iterator;
{
@@ -651,9 +656,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// 1. Build a new iterator or use the existing one.
if (make_new_iterator) {
// 1a. Get new input tensors or use the existing ones.
-
bool read_new_input;
-
{
tf_shared_lock l(ckpt_mu_);
// worker_thread_states_[thread_index].input will be non-empty
@@ -665,7 +668,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (read_new_input) {
mutex_lock l(mu_);
while (!cancelled_ && !workers_[thread_index].is_producing) {
+ StopWork(ctx.get());
workers_[thread_index].cond_var.wait(l);
+ StartWork(ctx.get());
}
if (cancelled_) return;
// Copy the input tensors so that we do not need to block on `mu_`
@@ -715,7 +720,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// Wait for space in the prefetch queue.
while (!cancelled_ && workers_[thread_index].outputs.size() ==
dataset()->buffer_output_elements_) {
+ StopWork(ctx.get());
workers_[thread_index].cond_var.wait(l);
+ StartWork(ctx.get());
}
if (cancelled_) return;
tf_shared_lock ckpt_l(ckpt_mu_);
@@ -764,7 +771,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// Wait for space in the prefetch queue.
while (!cancelled_ && workers_[thread_index].outputs.size() ==
dataset()->buffer_output_elements_) {
+ StopWork(ctx.get());
workers_[thread_index].cond_var.wait(l);
+ StartWork(ctx.get());
}
if (cancelled_) return;
@@ -1241,6 +1250,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
+ SetMetadata(ctx, "parallelism", dataset()->num_parallel_calls_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
@@ -1256,7 +1266,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
EnsureRunnerThreadStarted(ctx);
while (invocation_results_.empty() &&
(!end_of_input_ || num_open_ > 0)) {
+ StopWork(ctx);
cond_var_.wait(l);
+ StartWork(ctx);
}
if (!invocation_results_.empty()) {
std::swap(result, invocation_results_.front());
@@ -1267,7 +1279,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
}
}
cond_var_.notify_all();
+ StopWork(ctx);
result->notification.WaitForNotification();
+ StartWork(ctx);
} while (result->skip);
if (result->status.ok()) {
@@ -1391,6 +1405,8 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index,
const std::vector<std::shared_ptr<InvocationResult>>& results)
LOCKS_EXCLUDED(mu_) {
+ StartWork(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
bool end_of_input = false;
for (auto& result : results) {
if (!end_of_input) {
@@ -1433,6 +1449,8 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
//
// This method runs in the `runner_thread` background thread.
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
+ StartWork(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
while (true) {
{
mutex_lock l(mu_);
@@ -1443,7 +1461,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
(element_in_use_[cycle_index_] ||
num_calls_ >= dataset()->num_parallel_calls_ ||
invocation_results_.size() >= MaxInvocationResults())) {
+ StopWork(ctx.get());
cond_var_.wait(l);
+ StartWork(ctx.get());
}
if (cancelled_ || (end_of_input_ && num_open_ == 0)) {