diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/record_input_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/record_input_test.py | 14 |
1 files changed, 7 insertions, 7 deletions
diff --git a/tensorflow/python/kernel_tests/record_input_test.py b/tensorflow/python/kernel_tests/record_input_test.py index 068860d5d4..ebb9872f22 100644 --- a/tensorflow/python/kernel_tests/record_input_test.py +++ b/tensorflow/python/kernel_tests/record_input_test.py @@ -44,7 +44,7 @@ class RecordInputOpTest(test.TestCase): w.close() def testRecordInputSimple(self): - with self.test_session() as sess: + with self.cached_session() as sess: self.generateTestData("basic", 1, 1) yield_op = data_flow_ops.RecordInput( @@ -57,7 +57,7 @@ class RecordInputOpTest(test.TestCase): self.assertEqual(sess.run(yield_op), b"0000000000") def testRecordInputSimpleGzip(self): - with self.test_session() as sess: + with self.cached_session() as sess: self.generateTestData( "basic", 1, @@ -76,7 +76,7 @@ class RecordInputOpTest(test.TestCase): self.assertEqual(sess.run(yield_op), b"0000000000") def testRecordInputSimpleZlib(self): - with self.test_session() as sess: + with self.cached_session() as sess: self.generateTestData( "basic", 1, @@ -98,7 +98,7 @@ class RecordInputOpTest(test.TestCase): files = 100 records_per_file = 100 batches = 2 - with self.test_session() as sess: + with self.cached_session() as sess: self.generateTestData("basic", files, records_per_file) records = data_flow_ops.RecordInput( @@ -126,7 +126,7 @@ class RecordInputOpTest(test.TestCase): def testDoesNotDeadlock(self): # Iterate multiple times to cause deadlock if there is a chance it can occur for _ in range(30): - with self.test_session() as sess: + with self.cached_session() as sess: self.generateTestData("basic", 1, 1) records = data_flow_ops.RecordInput( @@ -141,7 +141,7 @@ class RecordInputOpTest(test.TestCase): sess.run(yield_op) def testEmptyGlob(self): - with self.test_session() as sess: + with self.cached_session() as sess: record_input = data_flow_ops.RecordInput(file_pattern="foo") yield_op = record_input.get_yield_op() sess.run(variables.global_variables_initializer()) @@ -152,7 +152,7 @@ class RecordInputOpTest(test.TestCase): files = 10 records_per_file = 10 batches = 2 - with self.test_session() as sess: + with self.cached_session() as sess: self.generateTestData("basic", files, records_per_file) records = data_flow_ops.RecordInput( |