aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data/ops/dataset_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/data/ops/dataset_ops.py')
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py13
1 files changed, 13 insertions, 0 deletions
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 6195747671..cdb883cac9 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -34,6 +34,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed as core_random_seed
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
@@ -178,10 +179,21 @@ class Dataset(object):
"""
if context.executing_eagerly():
return iterator_ops.EagerIterator(self)
+
+ graph_level_seed, op_level_seed = core_random_seed.get_seed(None)
+
# NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is
# a 0-argument function.
@function.Defun(capture_by_value=True)
def _make_dataset():
+ # NOTE(mrry): `Defun` does not capture the graph-level seed from the
+ # enclosing graph, so if a graph-level seed is present we set the local
+ # graph seed based on a combination of the graph- and op-level seeds.
+ if graph_level_seed is not None:
+ assert op_level_seed is not None
+ core_random_seed.set_random_seed(
+ (graph_level_seed + 87654321 * op_level_seed) % (2 ** 63 - 1))
+
dataset = self
options = self.options()
static_optimizations = options._static_optimizations() # pylint: disable=protected-access
@@ -2265,6 +2277,7 @@ class ShuffleDataset(UnaryDataset):
self._buffer_size = ops.convert_to_tensor(
buffer_size, dtype=dtypes.int64, name="buffer_size")
self._seed, self._seed2 = random_seed.get_seed(seed)
+
if reshuffle_each_iteration is None:
self._reshuffle_each_iteration = True
else: