diff options
Diffstat (limited to 'tensorflow/core/kernels/data/multi_device_iterator_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/data/multi_device_iterator_ops.cc | 34 |
1 files changed, 19 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc index 5f143967d9..d909b9e9d3 100644 --- a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc +++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc @@ -134,19 +134,17 @@ class MultiDeviceIterator : public ResourceBase { void Reset() LOCKS_EXCLUDED(mu_) { { mutex_lock l(mu_); - if (background_thread_finished_) { - return; - } - - cancelled_ = true; - // Wake up the background thread. - for (int i = 0; i < size_; ++i) { - buffer_[i].cond_var.notify_all(); - } + if (!background_thread_finished_) { + cancelled_ = true; + // Wake up the background thread. + for (int i = 0; i < size_; ++i) { + buffer_[i].cond_var.notify_all(); + } - // Make sure background thread has finished first. - while (!background_thread_finished_) { - shutdown_cond_var_.wait(l); + // Make sure background thread has finished first. + while (!background_thread_finished_) { + shutdown_cond_var_.wait(l); + } } } RunPendingCallbacks(); @@ -182,7 +180,7 @@ class MultiDeviceIterator : public ResourceBase { buffer_[shard_num].cond_var.notify_all(); } } else { - if (background_thread_finished_) { + if (end_of_iterator_) { produced_output = true; elem.end_of_sequence = true; } else { @@ -219,8 +217,12 @@ class MultiDeviceIterator : public ResourceBase { while (!buffer_[i].callbacks.empty()) { if (buffer_[i].data.empty()) { HostBufferElement elem; - elem.status = - errors::Cancelled("Cancelled and buffer not filled."); + if (end_of_iterator_) { + elem.end_of_sequence = true; + } else { + elem.status = + errors::Cancelled("Cancelled and buffer not filled."); + } cancellation_elements.push_back(std::move(elem)); } else { cancellation_elements.push_back( @@ -293,6 +295,7 @@ class MultiDeviceIterator : public ResourceBase { { mutex_lock l(mu_); background_thread_finished_ = true; + end_of_iterator_ = true; shutdown_cond_var_.notify_all(); } RunPendingCallbacks(); @@ -312,6 +315,7 @@ class MultiDeviceIterator : public ResourceBase { std::unique_ptr<Thread> background_thread_ GUARDED_BY(mu_); bool background_thread_finished_ GUARDED_BY(mu_) = false; bool background_thread_started_ GUARDED_BY(mu_) = false; + bool end_of_iterator_ GUARDED_BY(mu_) = false; bool cancelled_ GUARDED_BY(mu_) = false; condition_variable shutdown_cond_var_ GUARDED_BY(mu_); |