aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py')
-rw-r--r--tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py32
1 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
index 8694f58a24..cad28f860e 100644
--- a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
@@ -241,6 +241,38 @@ class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertAllEqual(results[0], results[1])
+ @parameterized.named_parameters(
+ ("ReshuffleOneShot", True, False),
+ ("ReshuffleInitializable", True, True),
+ ("NoReshuffleOneShot", False, False),
+ ("NoReshuffleInitializable", False, True),
+ )
+ def testMultipleIterators(self, reshuffle, initializable):
+ with ops.Graph().as_default() as g:
+ dataset = dataset_ops.Dataset.range(100).shuffle(
+ 10, reshuffle_each_iteration=reshuffle).repeat(3)
+
+ if initializable:
+ iterators = [dataset.make_initializable_iterator() for _ in range(2)]
+ else:
+ iterators = [dataset.make_one_shot_iterator() for _ in range(2)]
+
+ results = []
+ with self.session(graph=g) as sess:
+ for iterator in iterators:
+ if initializable:
+ sess.run(iterator.initializer)
+ next_element = iterator.get_next()
+ run_results = []
+ for _ in range(300):
+ run_results.append(sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ results.append(run_results)
+
+ self.assertNotEqual(results[0], results[1])
+
if __name__ == "__main__":
test.main()