aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/kernels/data/multi_device_iterator_ops.cc34
1 files changed, 19 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
index 5f143967d9..d909b9e9d3 100644
--- a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
+++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
@@ -134,19 +134,17 @@ class MultiDeviceIterator : public ResourceBase {
void Reset() LOCKS_EXCLUDED(mu_) {
{
mutex_lock l(mu_);
- if (background_thread_finished_) {
- return;
- }
-
- cancelled_ = true;
- // Wake up the background thread.
- for (int i = 0; i < size_; ++i) {
- buffer_[i].cond_var.notify_all();
- }
+ if (!background_thread_finished_) {
+ cancelled_ = true;
+ // Wake up the background thread.
+ for (int i = 0; i < size_; ++i) {
+ buffer_[i].cond_var.notify_all();
+ }
- // Make sure background thread has finished first.
- while (!background_thread_finished_) {
- shutdown_cond_var_.wait(l);
+ // Make sure background thread has finished first.
+ while (!background_thread_finished_) {
+ shutdown_cond_var_.wait(l);
+ }
}
}
RunPendingCallbacks();
@@ -182,7 +180,7 @@ class MultiDeviceIterator : public ResourceBase {
buffer_[shard_num].cond_var.notify_all();
}
} else {
- if (background_thread_finished_) {
+ if (end_of_iterator_) {
produced_output = true;
elem.end_of_sequence = true;
} else {
@@ -219,8 +217,12 @@ class MultiDeviceIterator : public ResourceBase {
while (!buffer_[i].callbacks.empty()) {
if (buffer_[i].data.empty()) {
HostBufferElement elem;
- elem.status =
- errors::Cancelled("Cancelled and buffer not filled.");
+ if (end_of_iterator_) {
+ elem.end_of_sequence = true;
+ } else {
+ elem.status =
+ errors::Cancelled("Cancelled and buffer not filled.");
+ }
cancellation_elements.push_back(std::move(elem));
} else {
cancellation_elements.push_back(
@@ -293,6 +295,7 @@ class MultiDeviceIterator : public ResourceBase {
{
mutex_lock l(mu_);
background_thread_finished_ = true;
+ end_of_iterator_ = true;
shutdown_cond_var_.notify_all();
}
RunPendingCallbacks();
@@ -312,6 +315,7 @@ class MultiDeviceIterator : public ResourceBase {
std::unique_ptr<Thread> background_thread_ GUARDED_BY(mu_);
bool background_thread_finished_ GUARDED_BY(mu_) = false;
bool background_thread_started_ GUARDED_BY(mu_) = false;
+ bool end_of_iterator_ GUARDED_BY(mu_) = false;
bool cancelled_ GUARDED_BY(mu_) = false;
condition_variable shutdown_cond_var_ GUARDED_BY(mu_);