diff options
author | Rohan Jain <rohanj@google.com> | 2018-09-21 14:05:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-21 14:08:39 -0700 |
commit | f4de7ec889311c42b3af4d5f34f7d31f56f73177 (patch) | |
tree | 5f92c0c2cdfad717c4c6216d9f4d2af7574def23 /tensorflow/contrib/data | |
parent | 3f40afa0409a2b22ff5a2e735418da7724aca0e8 (diff) |
Fixes a bug for the case when the MultiDeviceIterator waits on background
thread to finish even if None is running.
PiperOrigin-RevId: 214040824
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r-- | tensorflow/contrib/data/kernels/prefetching_kernels.cc | 13 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py | 9 |
2 files changed, 21 insertions, 1 deletions
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc index 078de717e0..39f23f7b24 100644 --- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc +++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc @@ -621,7 +621,13 @@ class MultiDeviceIterator : public ResourceBase { incarnation_id_(incarnation_id), host_iterator_(std::move(host_iterator)) {} - ~MultiDeviceBuffer() { Reset(); } + ~MultiDeviceBuffer() { + { + mutex_lock l(mu_); + if (!background_thread_started_) return; + } + Reset(); + } void Reset() LOCKS_EXCLUDED(mu_) { { @@ -731,6 +737,10 @@ class MultiDeviceIterator : public ResourceBase { } void BackgroundThread(IteratorContext* ctx) { + { + mutex_lock l(mu_); + background_thread_started_ = true; + } std::unique_ptr<IteratorContext> cleanup(ctx); int shard_to_fetch = 0; while (true) { @@ -799,6 +809,7 @@ class MultiDeviceIterator : public ResourceBase { mutex mu_; std::unique_ptr<Thread> background_thread_ GUARDED_BY(mu_); bool background_thread_finished_ GUARDED_BY(mu_) = false; + bool background_thread_started_ GUARDED_BY(mu_) = false; bool cancelled_ GUARDED_BY(mu_) = false; condition_variable shutdown_cond_var_ GUARDED_BY(mu_); diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py index 0166ba0d44..5b17511e41 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py @@ -946,6 +946,15 @@ class CopyToDeviceTest(test.TestCase): class MultiDeviceIteratorTest(test.TestCase): + def testNoGetNext(self): + dataset = dataset_ops.Dataset.range(10) + multi_device_iterator = prefetching_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/cpu:2"]) + + config = config_pb2.ConfigProto(device_count={"CPU": 3}) + with self.test_session(config=config) as sess: + sess.run(multi_device_iterator.initializer) + def testBasic(self): dataset = dataset_ops.Dataset.range(10) multi_device_iterator = prefetching_ops.MultiDeviceIterator( |