From 4ff7b81514ea1b86295bc74b620e3c1d3e127e6f Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Mon, 8 Oct 2018 17:37:44 -0700 Subject: Fix the seeding for `Dataset.shuffle(..., reshuffle_each_iteration=False)`. Previously, we were passing the first (graph-level) seed for both the graph-level and op-level seeds when creating a C++ dataset. This change passes the op-level seed to the appropriate point, and adds a test for the behavior with graph-but-not-op-level seeds. PiperOrigin-RevId: 216280641 --- tensorflow/core/kernels/data/shuffle_dataset_op.cc | 2 +- tensorflow/python/data/kernel_tests/BUILD | 3 ++ .../data/kernel_tests/shuffle_dataset_op_test.py | 35 +++++++++++++++++++++- 3 files changed, 38 insertions(+), 2 deletions(-) (limited to 'tensorflow') diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index 66466d6a36..9f54c381a9 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -485,7 +485,7 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase { int64 buffer_size, int64 seed, int64 seed2, int64 count) : ShuffleDatasetBase(ctx, input, buffer_size, count), seed_(seed), - seed2_(seed) {} + seed2_(seed2) {} string DebugString() const override { return strings::StrCat("ShuffleDatasetOp(", buffer_size_, ", ", seed_, diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index c7295d6e69..671b7ca1bb 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -443,12 +443,15 @@ tf_py_test( srcs = ["shuffle_dataset_op_test.py"], additional_deps = [ ":test_base", + "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:random_seed", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", ], diff --git a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py index 347af18576..8694f58a24 100644 --- a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import collections +from absl.testing import parameterized import numpy as np from tensorflow.python.data.kernel_tests import test_base @@ -27,11 +28,13 @@ from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class ShuffleDatasetTest(test_base.DatasetTestBase): +class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def testShuffleDataset(self): components = ( @@ -209,5 +212,35 @@ class ShuffleDatasetTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + @parameterized.named_parameters( + ("ReshuffleGraphLevelSeed", True, 38, None), + ("ReshuffleOpLevelSeed", True, None, 42), + ("ReshuffleGraphAndOpLevelSeed", True, 38, 42), + ("NoReshuffleGraphLevelSeed", False, 38, None), + ("NoReshuffleOpLevelSeed", False, None, 42), + ("NoReshuffleGraphAndOpLevelSeed", False, 38, 42), + ) + def testShuffleSeed(self, reshuffle, graph_level_seed, op_level_seed): + results = [] + for _ in range(2): + with ops.Graph().as_default() as g: + random_seed.set_random_seed(graph_level_seed) + dataset = dataset_ops.Dataset.range(10).shuffle( + 10, seed=op_level_seed, reshuffle_each_iteration=reshuffle).repeat( + 3) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + run_results = [] + with self.session(graph=g) as sess: + for _ in range(30): + run_results.append(sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + results.append(run_results) + + self.assertAllEqual(results[0], results[1]) + + if __name__ == "__main__": test.main() -- cgit v1.2.3