aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py')
-rw-r--r--tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
index 5fcc48831f..f294840706 100644
--- a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
@@ -60,7 +60,7 @@ class ShuffleDatasetTest(test.TestCase):
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# First run without shuffling to collect the "ground truth".
sess.run(init_fifo_op)
unshuffled_elements = []
@@ -140,7 +140,7 @@ class ShuffleDatasetTest(test.TestCase):
get_next = iterator.get_next()
elems = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(10):
elems.append(sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
@@ -152,7 +152,7 @@ class ShuffleDatasetTest(test.TestCase):
.make_initializable_iterator())
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer, feed_dict={seed_placeholder: 0})
for elem in elems:
self.assertEqual(elem, sess.run(get_next))
@@ -166,7 +166,7 @@ class ShuffleDatasetTest(test.TestCase):
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
counts = collections.defaultdict(lambda: 0)
for _ in range(10):
for _ in range(5):
@@ -183,7 +183,7 @@ class ShuffleDatasetTest(test.TestCase):
.make_one_shot_iterator())
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initial_permutation = sess.run(next_element)
self.assertAllEqual(initial_permutation, sess.run(next_element))
self.assertAllEqual(initial_permutation, sess.run(next_element))
@@ -198,7 +198,7 @@ class ShuffleDatasetTest(test.TestCase):
.make_one_shot_iterator())
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initial_permutation = list(sess.run(next_element))
for _ in range(2):
next_permutation = list(sess.run(next_element))