diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/io_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/io_ops_test.py | 37 |
1 files changed, 25 insertions, 12 deletions
diff --git a/tensorflow/python/kernel_tests/io_ops_test.py b/tensorflow/python/kernel_tests/io_ops_test.py index d484a609fc..b0c46ea07d 100644 --- a/tensorflow/python/kernel_tests/io_ops_test.py +++ b/tensorflow/python/kernel_tests/io_ops_test.py @@ -1,4 +1,4 @@ -# -*- coding: utf-8 -*- +# -*- coding: utf-8 -*- # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,6 +20,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import tempfile import tensorflow as tf @@ -31,25 +32,31 @@ class IoOpsTest(tf.test.TestCase): cases = ['', 'Some contents', 'Неки садржаји на српском'] for contents in cases: contents = tf.compat.as_bytes(contents) - temp = tempfile.NamedTemporaryFile( - prefix='ReadFileTest', dir=self.get_temp_dir()) - open(temp.name, 'wb').write(contents) + with tempfile.NamedTemporaryFile(prefix='ReadFileTest', + dir=self.get_temp_dir(), + delete=False) as temp: + temp.write(contents) with self.test_session(): read = tf.read_file(temp.name) self.assertEqual([], read.get_shape()) self.assertEqual(read.eval(), contents) + os.remove(temp.name) def testWriteFile(self): cases = ['', 'Some contents'] for contents in cases: contents = tf.compat.as_bytes(contents) - temp = tempfile.NamedTemporaryFile( - prefix='WriteFileTest', dir=self.get_temp_dir()) + with tempfile.NamedTemporaryFile(prefix='WriteFileTest', + dir=self.get_temp_dir(), + delete=False) as temp: + pass with self.test_session() as sess: w = tf.write_file(temp.name, contents) sess.run(w) - file_contents = open(temp.name, 'rb').read() + with open(temp.name, 'rb') as f: + file_contents = f.read() self.assertEqual(file_contents, contents) + os.remove(temp.name) def _subset(self, files, indices): return set(tf.compat.as_bytes(files[i].name) @@ -59,7 +66,7 @@ class IoOpsTest(tf.test.TestCase): cases = ['ABcDEF.GH', 'ABzDEF.GH', 'ABasdfjklDEF.GH', 'AB3DEF.GH', 'AB4DEF.GH', 'ABDEF.GH', 'XYZ'] files = [tempfile.NamedTemporaryFile( - prefix=c, dir=self.get_temp_dir()) for c in cases] + prefix=c, dir=self.get_temp_dir(), delete=True) for c in cases] with self.test_session(): # Test exact match without wildcards. @@ -77,10 +84,16 @@ class IoOpsTest(tf.test.TestCase): self._subset(files, [0, 1, 3, 4])) self.assertEqual(set(tf.matching_files(pattern % '*').eval()), self._subset(files, [0, 1, 2, 3, 4, 5])) - self.assertEqual(set(tf.matching_files(pattern % '[cxz]').eval()), - self._subset(files, [0, 1])) - self.assertEqual(set(tf.matching_files(pattern % '[0-9]').eval()), - self._subset(files, [3, 4])) + # NOTE(mrry): Windows uses PathMatchSpec to match file patterns, which + # does not support the following expressions. + if os.name != 'nt': + self.assertEqual(set(tf.matching_files(pattern % '[cxz]').eval()), + self._subset(files, [0, 1])) + self.assertEqual(set(tf.matching_files(pattern % '[0-9]').eval()), + self._subset(files, [3, 4])) + + for f in files: + f.close() if __name__ == '__main__': |