diff options
Diffstat (limited to 'tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py')
-rw-r--r-- | tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py index 8694f58a24..cad28f860e 100644 --- a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py @@ -241,6 +241,38 @@ class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertAllEqual(results[0], results[1]) + @parameterized.named_parameters( + ("ReshuffleOneShot", True, False), + ("ReshuffleInitializable", True, True), + ("NoReshuffleOneShot", False, False), + ("NoReshuffleInitializable", False, True), + ) + def testMultipleIterators(self, reshuffle, initializable): + with ops.Graph().as_default() as g: + dataset = dataset_ops.Dataset.range(100).shuffle( + 10, reshuffle_each_iteration=reshuffle).repeat(3) + + if initializable: + iterators = [dataset.make_initializable_iterator() for _ in range(2)] + else: + iterators = [dataset.make_one_shot_iterator() for _ in range(2)] + + results = [] + with self.session(graph=g) as sess: + for iterator in iterators: + if initializable: + sess.run(iterator.initializer) + next_element = iterator.get_next() + run_results = [] + for _ in range(300): + run_results.append(sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + results.append(run_results) + + self.assertNotEqual(results[0], results[1]) + if __name__ == "__main__": test.main() |