diff options
author | Derek Murray <mrry@google.com> | 2018-10-10 08:12:24 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-10 08:16:27 -0700 |
commit | 1ae0a45a5de65ab4ae6def232da016e7ee32773c (patch) | |
tree | 1169c6709ed2ae4b99d21f37f5435c6ac82dc978 | |
parent | 0bb68afa38cf5c45232e85fb09186e01055e4d11 (diff) |
[tf.data] `Dataset.make_one_shot_iterator()` inherits the random seed from the calling graph.
This change makes a subtle difference to the behavior of existing
programs that create multiple iterators. Previously, one-shot
iterators would not inherit the graph seed, and so their values would
be non-deterministic (unless explicit seeds were set). After this
change, an iterator will inherit its seed from the outer
graph. Multiple one-shot iterators created from the same dataset will
inherit different seeds, matching the semantics of creating multiple
ops with the same graph seed.
PiperOrigin-RevId: 216532256
-rw-r--r-- | tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py | 32 | ||||
-rw-r--r-- | tensorflow/python/data/ops/dataset_ops.py | 13 |
2 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py index 8694f58a24..cad28f860e 100644 --- a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py @@ -241,6 +241,38 @@ class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertAllEqual(results[0], results[1]) + @parameterized.named_parameters( + ("ReshuffleOneShot", True, False), + ("ReshuffleInitializable", True, True), + ("NoReshuffleOneShot", False, False), + ("NoReshuffleInitializable", False, True), + ) + def testMultipleIterators(self, reshuffle, initializable): + with ops.Graph().as_default() as g: + dataset = dataset_ops.Dataset.range(100).shuffle( + 10, reshuffle_each_iteration=reshuffle).repeat(3) + + if initializable: + iterators = [dataset.make_initializable_iterator() for _ in range(2)] + else: + iterators = [dataset.make_one_shot_iterator() for _ in range(2)] + + results = [] + with self.session(graph=g) as sess: + for iterator in iterators: + if initializable: + sess.run(iterator.initializer) + next_element = iterator.get_next() + run_results = [] + for _ in range(300): + run_results.append(sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + results.append(run_results) + + self.assertNotEqual(results[0], results[1]) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 6195747671..cdb883cac9 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -34,6 +34,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed as core_random_seed from tensorflow.python.framework import smart_cond from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape @@ -178,10 +179,21 @@ class Dataset(object): """ if context.executing_eagerly(): return iterator_ops.EagerIterator(self) + + graph_level_seed, op_level_seed = core_random_seed.get_seed(None) + # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is # a 0-argument function. @function.Defun(capture_by_value=True) def _make_dataset(): + # NOTE(mrry): `Defun` does not capture the graph-level seed from the + # enclosing graph, so if a graph-level seed is present we set the local + # graph seed based on a combination of the graph- and op-level seeds. + if graph_level_seed is not None: + assert op_level_seed is not None + core_random_seed.set_random_seed( + (graph_level_seed + 87654321 * op_level_seed) % (2 ** 63 - 1)) + dataset = self options = self.options() static_optimizations = options._static_optimizations() # pylint: disable=protected-access @@ -2265,6 +2277,7 @@ class ShuffleDataset(UnaryDataset): self._buffer_size = ops.convert_to_tensor( buffer_size, dtype=dtypes.int64, name="buffer_size") self._seed, self._seed2 = random_seed.get_seed(seed) + if reshuffle_each_iteration is None: self._reshuffle_each_iteration = True else: |