aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/data/kernel_tests/shard_dataset_op_test.py')
-rw-r--r--tensorflow/python/data/kernel_tests/shard_dataset_op_test.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
index cefe872d0f..137f6341ce 100644
--- a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
@@ -28,7 +28,7 @@ class ShardDatasetOpTest(test.TestCase):
dataset = dataset_ops.Dataset.range(10).shard(5, 2)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(2, sess.run(iterator.get_next()))
self.assertEqual(7, sess.run(iterator.get_next()))
with self.assertRaises(errors.OutOfRangeError):
@@ -40,7 +40,7 @@ class ShardDatasetOpTest(test.TestCase):
dataset = dataset_ops.Dataset.zip((dataset_a, dataset_b)).shard(5, 2)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual((2, 8), sess.run(iterator.get_next()))
self.assertEqual((7, 3), sess.run(iterator.get_next()))
with self.assertRaises(errors.OutOfRangeError):
@@ -50,7 +50,7 @@ class ShardDatasetOpTest(test.TestCase):
dataset = dataset_ops.Dataset.range(10).shard(5, 0)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(0, sess.run(iterator.get_next()))
self.assertEqual(5, sess.run(iterator.get_next()))
with self.assertRaises(errors.OutOfRangeError):
@@ -76,14 +76,14 @@ class ShardDatasetOpTest(test.TestCase):
dataset = dataset_ops.Dataset.range(1).shard(5, 2)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.OutOfRangeError):
sess.run(iterator.get_next())
def testLargerWorkerPool(self):
dataset = dataset_ops.Dataset.range(10).shard(7, 5)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(5, sess.run(iterator.get_next()))
with self.assertRaises(errors.OutOfRangeError):
sess.run(iterator.get_next())
@@ -91,7 +91,7 @@ class ShardDatasetOpTest(test.TestCase):
def testIndexEqualsNumShards(self):
dataset = dataset_ops.Dataset.range(10).shard(5, 4)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(4, sess.run(iterator.get_next()))
self.assertEqual(9, sess.run(iterator.get_next()))
with self.assertRaises(errors.OutOfRangeError):
@@ -100,7 +100,7 @@ class ShardDatasetOpTest(test.TestCase):
def testIndexEqualsNumShards2(self):
dataset = dataset_ops.Dataset.range(10).shard(4, 3)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(3, sess.run(iterator.get_next()))
self.assertEqual(7, sess.run(iterator.get_next()))
with self.assertRaises(errors.OutOfRangeError):