aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py')
-rw-r--r--tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py26
1 files changed, 13 insertions, 13 deletions
diff --git a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
index 431362aa9a..aa3636364d 100644
--- a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
@@ -100,7 +100,7 @@ class TextLineDatasetTest(test.TestCase):
init_batch_op = iterator.make_initializer(batch_dataset)
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Basic test: read from file 0.
sess.run(
init_op, feed_dict={filenames: [test_filenames[0]],
@@ -163,7 +163,7 @@ class TextLineDatasetTest(test.TestCase):
repeat_dataset = readers.TextLineDataset(test_filenames, buffer_size=10)
iterator = repeat_dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for j in range(2):
for i in range(5):
self.assertEqual(self._lineText(j, i), sess.run(iterator.get_next()))
@@ -240,7 +240,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
init_batch_op = iterator.make_initializer(batch_dataset)
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Basic test: read from file 0.
sess.run(
init_op, feed_dict={filenames: [test_filenames[0]],
@@ -302,7 +302,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
buffer_size=10)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for j in range(self._num_files):
for i in range(self._num_records):
self.assertEqual(self._record(j, i), sess.run(iterator.get_next()))
@@ -319,7 +319,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
buffer_size=10)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r"Excluding the header \(5 bytes\) and footer \(2 bytes\), input "
@@ -661,7 +661,7 @@ class TFRecordDatasetTest(test.TestCase):
return filenames
def testReadOneEpoch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Basic test: read from file 0.
sess.run(
self.init_op,
@@ -698,7 +698,7 @@ class TFRecordDatasetTest(test.TestCase):
sess.run(self.get_next)
def testReadTenEpochs(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.init_op,
feed_dict={self.filenames: self.test_filenames,
@@ -711,7 +711,7 @@ class TFRecordDatasetTest(test.TestCase):
sess.run(self.get_next)
def testReadTenEpochsOfBatches(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.init_batch_op,
feed_dict={
@@ -738,7 +738,7 @@ class TFRecordDatasetTest(test.TestCase):
f.write(cdata)
zlib_files.append(zfn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.init_op,
feed_dict={self.filenames: zlib_files,
@@ -758,7 +758,7 @@ class TFRecordDatasetTest(test.TestCase):
gzf.write(f.read())
gzip_files.append(gzfn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.init_op,
feed_dict={self.filenames: gzip_files,
@@ -774,7 +774,7 @@ class TFRecordDatasetTest(test.TestCase):
d = readers.TFRecordDataset(self.test_filenames, buffer_size=one_mebibyte)
iterator = d.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for j in range(self._num_files):
for i in range(self._num_records):
self.assertAllEqual(self._record(j, i), sess.run(next_element))
@@ -786,7 +786,7 @@ class TFRecordDatasetTest(test.TestCase):
d = readers.TFRecordDataset(files)
iterator = d.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for j in range(self._num_files):
for i in range(self._num_records):
self.assertAllEqual(self._record(j, i), sess.run(next_element))
@@ -801,7 +801,7 @@ class TFRecordDatasetTest(test.TestCase):
next_element = iterator.get_next()
expected = []
actual = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(10):
for j in range(self._num_files):
for i in range(self._num_records):