diff options
Diffstat (limited to 'tensorflow/python/data/ops/dataset_ops.py')
-rw-r--r-- | tensorflow/python/data/ops/dataset_ops.py | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 6195747671..cdb883cac9 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -34,6 +34,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed as core_random_seed from tensorflow.python.framework import smart_cond from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape @@ -178,10 +179,21 @@ class Dataset(object): """ if context.executing_eagerly(): return iterator_ops.EagerIterator(self) + + graph_level_seed, op_level_seed = core_random_seed.get_seed(None) + # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is # a 0-argument function. @function.Defun(capture_by_value=True) def _make_dataset(): + # NOTE(mrry): `Defun` does not capture the graph-level seed from the + # enclosing graph, so if a graph-level seed is present we set the local + # graph seed based on a combination of the graph- and op-level seeds. + if graph_level_seed is not None: + assert op_level_seed is not None + core_random_seed.set_random_seed( + (graph_level_seed + 87654321 * op_level_seed) % (2 ** 63 - 1)) + dataset = self options = self.options() static_optimizations = options._static_optimizations() # pylint: disable=protected-access @@ -2265,6 +2277,7 @@ class ShuffleDataset(UnaryDataset): self._buffer_size = ops.convert_to_tensor( buffer_size, dtype=dtypes.int64, name="buffer_size") self._seed, self._seed2 = random_seed.get_seed(seed) + if reshuffle_each_iteration is None: self._reshuffle_each_iteration = True else: |