diff options
Diffstat (limited to 'tensorflow/python/data/kernel_tests/shard_dataset_op_test.py')
-rw-r--r-- | tensorflow/python/data/kernel_tests/shard_dataset_op_test.py | 14 |
1 files changed, 7 insertions, 7 deletions
diff --git a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py index cefe872d0f..137f6341ce 100644 --- a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py @@ -28,7 +28,7 @@ class ShardDatasetOpTest(test.TestCase): dataset = dataset_ops.Dataset.range(10).shard(5, 2) iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(2, sess.run(iterator.get_next())) self.assertEqual(7, sess.run(iterator.get_next())) with self.assertRaises(errors.OutOfRangeError): @@ -40,7 +40,7 @@ class ShardDatasetOpTest(test.TestCase): dataset = dataset_ops.Dataset.zip((dataset_a, dataset_b)).shard(5, 2) iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual((2, 8), sess.run(iterator.get_next())) self.assertEqual((7, 3), sess.run(iterator.get_next())) with self.assertRaises(errors.OutOfRangeError): @@ -50,7 +50,7 @@ class ShardDatasetOpTest(test.TestCase): dataset = dataset_ops.Dataset.range(10).shard(5, 0) iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(0, sess.run(iterator.get_next())) self.assertEqual(5, sess.run(iterator.get_next())) with self.assertRaises(errors.OutOfRangeError): @@ -76,14 +76,14 @@ class ShardDatasetOpTest(test.TestCase): dataset = dataset_ops.Dataset.range(1).shard(5, 2) iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.OutOfRangeError): sess.run(iterator.get_next()) def testLargerWorkerPool(self): dataset = dataset_ops.Dataset.range(10).shard(7, 5) iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(5, sess.run(iterator.get_next())) with self.assertRaises(errors.OutOfRangeError): sess.run(iterator.get_next()) @@ -91,7 +91,7 @@ class ShardDatasetOpTest(test.TestCase): def testIndexEqualsNumShards(self): dataset = dataset_ops.Dataset.range(10).shard(5, 4) iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(4, sess.run(iterator.get_next())) self.assertEqual(9, sess.run(iterator.get_next())) with self.assertRaises(errors.OutOfRangeError): @@ -100,7 +100,7 @@ class ShardDatasetOpTest(test.TestCase): def testIndexEqualsNumShards2(self): dataset = dataset_ops.Dataset.range(10).shard(4, 3) iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(3, sess.run(iterator.get_next())) self.assertEqual(7, sess.run(iterator.get_next())) with self.assertRaises(errors.OutOfRangeError): |