aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2018-09-21 14:05:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-21 14:08:39 -0700
commitf4de7ec889311c42b3af4d5f34f7d31f56f73177 (patch)
tree5f92c0c2cdfad717c4c6216d9f4d2af7574def23 /tensorflow/contrib/data
parent3f40afa0409a2b22ff5a2e735418da7724aca0e8 (diff)
Fixes a bug for the case when the MultiDeviceIterator waits on background
thread to finish even if None is running. PiperOrigin-RevId: 214040824
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--tensorflow/contrib/data/kernels/prefetching_kernels.cc13
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py9
2 files changed, 21 insertions, 1 deletions
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
index 078de717e0..39f23f7b24 100644
--- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc
+++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
@@ -621,7 +621,13 @@ class MultiDeviceIterator : public ResourceBase {
incarnation_id_(incarnation_id),
host_iterator_(std::move(host_iterator)) {}
- ~MultiDeviceBuffer() { Reset(); }
+ ~MultiDeviceBuffer() {
+ {
+ mutex_lock l(mu_);
+ if (!background_thread_started_) return;
+ }
+ Reset();
+ }
void Reset() LOCKS_EXCLUDED(mu_) {
{
@@ -731,6 +737,10 @@ class MultiDeviceIterator : public ResourceBase {
}
void BackgroundThread(IteratorContext* ctx) {
+ {
+ mutex_lock l(mu_);
+ background_thread_started_ = true;
+ }
std::unique_ptr<IteratorContext> cleanup(ctx);
int shard_to_fetch = 0;
while (true) {
@@ -799,6 +809,7 @@ class MultiDeviceIterator : public ResourceBase {
mutex mu_;
std::unique_ptr<Thread> background_thread_ GUARDED_BY(mu_);
bool background_thread_finished_ GUARDED_BY(mu_) = false;
+ bool background_thread_started_ GUARDED_BY(mu_) = false;
bool cancelled_ GUARDED_BY(mu_) = false;
condition_variable shutdown_cond_var_ GUARDED_BY(mu_);
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index 0166ba0d44..5b17511e41 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -946,6 +946,15 @@ class CopyToDeviceTest(test.TestCase):
class MultiDeviceIteratorTest(test.TestCase):
+ def testNoGetNext(self):
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = prefetching_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2"])
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+
def testBasic(self):
dataset = dataset_ops.Dataset.range(10)
multi_device_iterator = prefetching_ops.MultiDeviceIterator(