diff options
Diffstat (limited to 'tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py')
-rw-r--r-- | tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py | 26 |
1 files changed, 13 insertions, 13 deletions
diff --git a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py index 431362aa9a..aa3636364d 100644 --- a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py @@ -100,7 +100,7 @@ class TextLineDatasetTest(test.TestCase): init_batch_op = iterator.make_initializer(batch_dataset) get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # Basic test: read from file 0. sess.run( init_op, feed_dict={filenames: [test_filenames[0]], @@ -163,7 +163,7 @@ class TextLineDatasetTest(test.TestCase): repeat_dataset = readers.TextLineDataset(test_filenames, buffer_size=10) iterator = repeat_dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: for j in range(2): for i in range(5): self.assertEqual(self._lineText(j, i), sess.run(iterator.get_next())) @@ -240,7 +240,7 @@ class FixedLengthRecordReaderTest(test.TestCase): init_batch_op = iterator.make_initializer(batch_dataset) get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # Basic test: read from file 0. sess.run( init_op, feed_dict={filenames: [test_filenames[0]], @@ -302,7 +302,7 @@ class FixedLengthRecordReaderTest(test.TestCase): buffer_size=10) iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: for j in range(self._num_files): for i in range(self._num_records): self.assertEqual(self._record(j, i), sess.run(iterator.get_next())) @@ -319,7 +319,7 @@ class FixedLengthRecordReaderTest(test.TestCase): buffer_size=10) iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesRegexp( errors.InvalidArgumentError, r"Excluding the header \(5 bytes\) and footer \(2 bytes\), input " @@ -661,7 +661,7 @@ class TFRecordDatasetTest(test.TestCase): return filenames def testReadOneEpoch(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Basic test: read from file 0. sess.run( self.init_op, @@ -698,7 +698,7 @@ class TFRecordDatasetTest(test.TestCase): sess.run(self.get_next) def testReadTenEpochs(self): - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.init_op, feed_dict={self.filenames: self.test_filenames, @@ -711,7 +711,7 @@ class TFRecordDatasetTest(test.TestCase): sess.run(self.get_next) def testReadTenEpochsOfBatches(self): - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.init_batch_op, feed_dict={ @@ -738,7 +738,7 @@ class TFRecordDatasetTest(test.TestCase): f.write(cdata) zlib_files.append(zfn) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.init_op, feed_dict={self.filenames: zlib_files, @@ -758,7 +758,7 @@ class TFRecordDatasetTest(test.TestCase): gzf.write(f.read()) gzip_files.append(gzfn) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.init_op, feed_dict={self.filenames: gzip_files, @@ -774,7 +774,7 @@ class TFRecordDatasetTest(test.TestCase): d = readers.TFRecordDataset(self.test_filenames, buffer_size=one_mebibyte) iterator = d.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for j in range(self._num_files): for i in range(self._num_records): self.assertAllEqual(self._record(j, i), sess.run(next_element)) @@ -786,7 +786,7 @@ class TFRecordDatasetTest(test.TestCase): d = readers.TFRecordDataset(files) iterator = d.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for j in range(self._num_files): for i in range(self._num_records): self.assertAllEqual(self._record(j, i), sess.run(next_element)) @@ -801,7 +801,7 @@ class TFRecordDatasetTest(test.TestCase): next_element = iterator.get_next() expected = [] actual = [] - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(10): for j in range(self._num_files): for i in range(self._num_records): |