aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-10-08 17:37:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 17:44:23 -0700
commit4ff7b81514ea1b86295bc74b620e3c1d3e127e6f (patch)
treebfa49352d5ff057222784de6cba050548c6c3a03
parent03d097bc96080981098ffdbaf1b3465e6e153a6a (diff)
Fix the seeding for `Dataset.shuffle(..., reshuffle_each_iteration=False)`.
Previously, we were passing the first (graph-level) seed for both the graph-level and op-level seeds when creating a C++ dataset. This change passes the op-level seed to the appropriate point, and adds a test for the behavior with graph-but-not-op-level seeds. PiperOrigin-RevId: 216280641
-rw-r--r--tensorflow/core/kernels/data/shuffle_dataset_op.cc2
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD3
-rw-r--r--tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py35
3 files changed, 38 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
index 66466d6a36..9f54c381a9 100644
--- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc
+++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
@@ -485,7 +485,7 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase {
int64 buffer_size, int64 seed, int64 seed2, int64 count)
: ShuffleDatasetBase(ctx, input, buffer_size, count),
seed_(seed),
- seed2_(seed) {}
+ seed2_(seed2) {}
string DebugString() const override {
return strings::StrCat("ShuffleDatasetOp(", buffer_size_, ", ", seed_,
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index c7295d6e69..671b7ca1bb 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -443,12 +443,15 @@ tf_py_test(
srcs = ["shuffle_dataset_op_test.py"],
additional_deps = [
":test_base",
+ "@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:random_seed",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
],
diff --git a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
index 347af18576..8694f58a24 100644
--- a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import collections
+from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.kernel_tests import test_base
@@ -27,11 +28,13 @@ from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class ShuffleDatasetTest(test_base.DatasetTestBase):
+class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def testShuffleDataset(self):
components = (
@@ -209,5 +212,35 @@ class ShuffleDatasetTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
+ @parameterized.named_parameters(
+ ("ReshuffleGraphLevelSeed", True, 38, None),
+ ("ReshuffleOpLevelSeed", True, None, 42),
+ ("ReshuffleGraphAndOpLevelSeed", True, 38, 42),
+ ("NoReshuffleGraphLevelSeed", False, 38, None),
+ ("NoReshuffleOpLevelSeed", False, None, 42),
+ ("NoReshuffleGraphAndOpLevelSeed", False, 38, 42),
+ )
+ def testShuffleSeed(self, reshuffle, graph_level_seed, op_level_seed):
+ results = []
+ for _ in range(2):
+ with ops.Graph().as_default() as g:
+ random_seed.set_random_seed(graph_level_seed)
+ dataset = dataset_ops.Dataset.range(10).shuffle(
+ 10, seed=op_level_seed, reshuffle_each_iteration=reshuffle).repeat(
+ 3)
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ run_results = []
+ with self.session(graph=g) as sess:
+ for _ in range(30):
+ run_results.append(sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ results.append(run_results)
+
+ self.assertAllEqual(results[0], results[1])
+
+
if __name__ == "__main__":
test.main()