From 1ae0a45a5de65ab4ae6def232da016e7ee32773c Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Wed, 10 Oct 2018 08:12:24 -0700 Subject: [tf.data] `Dataset.make_one_shot_iterator()` inherits the random seed from the calling graph. This change makes a subtle difference to the behavior of existing programs that create multiple iterators. Previously, one-shot iterators would not inherit the graph seed, and so their values would be non-deterministic (unless explicit seeds were set). After this change, an iterator will inherit its seed from the outer graph. Multiple one-shot iterators created from the same dataset will inherit different seeds, matching the semantics of creating multiple ops with the same graph seed. PiperOrigin-RevId: 216532256 --- .../data/kernel_tests/shuffle_dataset_op_test.py | 32 ++++++++++++++++++++++ tensorflow/python/data/ops/dataset_ops.py | 13 +++++++++ 2 files changed, 45 insertions(+) diff --git a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py index 8694f58a24..cad28f860e 100644 --- a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py @@ -241,6 +241,38 @@ class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertAllEqual(results[0], results[1]) + @parameterized.named_parameters( + ("ReshuffleOneShot", True, False), + ("ReshuffleInitializable", True, True), + ("NoReshuffleOneShot", False, False), + ("NoReshuffleInitializable", False, True), + ) + def testMultipleIterators(self, reshuffle, initializable): + with ops.Graph().as_default() as g: + dataset = dataset_ops.Dataset.range(100).shuffle( + 10, reshuffle_each_iteration=reshuffle).repeat(3) + + if initializable: + iterators = [dataset.make_initializable_iterator() for _ in range(2)] + else: + iterators = [dataset.make_one_shot_iterator() for _ in range(2)] + + results = [] + with self.session(graph=g) as sess: + for iterator in iterators: + if initializable: + sess.run(iterator.initializer) + next_element = iterator.get_next() + run_results = [] + for _ in range(300): + run_results.append(sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + results.append(run_results) + + self.assertNotEqual(results[0], results[1]) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 6195747671..cdb883cac9 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -34,6 +34,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed as core_random_seed from tensorflow.python.framework import smart_cond from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape @@ -178,10 +179,21 @@ class Dataset(object): """ if context.executing_eagerly(): return iterator_ops.EagerIterator(self) + + graph_level_seed, op_level_seed = core_random_seed.get_seed(None) + # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is # a 0-argument function. @function.Defun(capture_by_value=True) def _make_dataset(): + # NOTE(mrry): `Defun` does not capture the graph-level seed from the + # enclosing graph, so if a graph-level seed is present we set the local + # graph seed based on a combination of the graph- and op-level seeds. + if graph_level_seed is not None: + assert op_level_seed is not None + core_random_seed.set_random_seed( + (graph_level_seed + 87654321 * op_level_seed) % (2 ** 63 - 1)) + dataset = self options = self.options() static_optimizations = options._static_optimizations() # pylint: disable=protected-access @@ -2265,6 +2277,7 @@ class ShuffleDataset(UnaryDataset): self._buffer_size = ops.convert_to_tensor( buffer_size, dtype=dtypes.int64, name="buffer_size") self._seed, self._seed2 = random_seed.get_seed(seed) + if reshuffle_each_iteration is None: self._reshuffle_each_iteration = True else: -- cgit v1.2.3