about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author Gravatar Derek Murray <mrry@google.com> 2018-10-10 08:12:24 -0700
committer Gravatar TensorFlower Gardener <gardener@tensorflow.org> 2018-10-10 08:16:27 -0700
commit 1ae0a45a5de65ab4ae6def232da016e7ee32773c (patch)
tree 1169c6709ed2ae4b99d21f37f5435c6ac82dc978
parent 0bb68afa38cf5c45232e85fb09186e01055e4d11 (diff)
[tf.data] `Dataset.make_one_shot_iterator()` inherits the random seed from the calling graph.
This change makes a subtle difference to the behavior of existing programs that create multiple iterators. Previously, one-shot iterators would not inherit the graph seed, and so their values would be non-deterministic (unless explicit seeds were set). After this change, an iterator will inherit its seed from the outer graph. Multiple one-shot iterators created from the same dataset will inherit different seeds, matching the semantics of creating multiple ops with the same graph seed.

PiperOrigin-RevId: 216532256
-rw-r--r-- tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py 32
-rw-r--r-- tensorflow/python/data/ops/dataset_ops.py 13
2 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
index 8694f58a24..cad28f860e 100644
--- a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
@@ -241,6 +241,38 @@ class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertAllEqual(results[0], results[1])
+ @parameterized.named_parameters(
+ ("ReshuffleOneShot", True, False),
+ ("ReshuffleInitializable", True, True),
+ ("NoReshuffleOneShot", False, False),
+ ("NoReshuffleInitializable", False, True),
+ )
+ def testMultipleIterators(self, reshuffle, initializable):
+ with ops.Graph().as_default() as g:
+ dataset = dataset_ops.Dataset.range(100).shuffle(
+ 10, reshuffle_each_iteration=reshuffle).repeat(3)
+
+ if initializable:
+ iterators = [dataset.make_initializable_iterator() for _ in range(2)]
+ else:
+ iterators = [dataset.make_one_shot_iterator() for _ in range(2)]
+
+ results = []
+ with self.session(graph=g) as sess:
+ for iterator in iterators:
+ if initializable:
+ sess.run(iterator.initializer)
+ next_element = iterator.get_next()
+ run_results = []
+ for _ in range(300):
+ run_results.append(sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ results.append(run_results)
+
+ self.assertNotEqual(results[0], results[1])
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 6195747671..cdb883cac9 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -34,6 +34,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed as core_random_seed
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
@@ -178,10 +179,21 @@ class Dataset(object):
"""
if context.executing_eagerly():
return iterator_ops.EagerIterator(self)
+
+ graph_level_seed, op_level_seed = core_random_seed.get_seed(None)
+
# NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is
# a 0-argument function.
@function.Defun(capture_by_value=True)
def _make_dataset():
+ # NOTE(mrry): `Defun` does not capture the graph-level seed from the
+ # enclosing graph, so if a graph-level seed is present we set the local
+ # graph seed based on a combination of the graph- and op-level seeds.
+ if graph_level_seed is not None:
+ assert op_level_seed is not None
+ core_random_seed.set_random_seed(
+ (graph_level_seed + 87654321 * op_level_seed) % (2 ** 63 - 1))
+
dataset = self
options = self.options()
static_optimizations = options._static_optimizations() # pylint: disable=protected-access
@@ -2265,6 +2277,7 @@ class ShuffleDataset(UnaryDataset):
self._buffer_size = ops.convert_to_tensor(
buffer_size, dtype=dtypes.int64, name="buffer_size")
self._seed, self._seed2 = random_seed.get_seed(seed)
+
if reshuffle_each_iteration is None:
self._reshuffle_each_iteration = True
else: