diff options
author | 2016-09-13 18:55:03 -0800 | |
---|---|---|
committer | 2016-09-13 20:03:08 -0700 | |
commit | 0d13cdf6a1bb6a5d73b340f77338f2c3e1139bc3 (patch) | |
tree | 7d73fc2cd79cb6056b8e1fe152d2faf6f312a63b /tensorflow/python/kernel_tests/reader_ops_test.py | |
parent | 2e9964a2b199fe3e10e3f9b6fe28342ebd621a3b (diff) |
Plumb the existing support for gzip compression in TensorFlow
out to the RecordWriter/Reader, and add tests to validate its compatibility
with external gzip libraries.
Change: 133081474
Diffstat (limited to 'tensorflow/python/kernel_tests/reader_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/reader_ops_test.py | 54 |
1 files changed, 49 insertions, 5 deletions
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py index bdc73e4d51..dd97524357 100644 --- a/tensorflow/python/kernel_tests/reader_ops_test.py +++ b/tensorflow/python/kernel_tests/reader_ops_test.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function import collections +import gzip import os import threading import zlib @@ -27,6 +28,11 @@ import zlib import six import tensorflow as tf +# pylint: disable=invalid-name +TFRecordCompressionType = tf.python_io.TFRecordCompressionType +# pylint: enable=invalid-name + + # Edgar Allan Poe's 'Eldorado' _TEXT = b"""Gaily bedight, A gallant knight, @@ -461,7 +467,7 @@ class TFRecordWriterZlibTest(tf.test.TestCase): fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) filenames.append(fn) options = tf.python_io.TFRecordOptions( - compression_type=tf.python_io.TFRecordCompressionType.ZLIB) + compression_type=TFRecordCompressionType.ZLIB) writer = tf.python_io.TFRecordWriter(fn, options=options) for j in range(self._num_records): writer.write(self._Record(i, j)) @@ -493,7 +499,7 @@ class TFRecordWriterZlibTest(tf.test.TestCase): files = self._CreateFiles() with self.test_session() as sess: options = tf.python_io.TFRecordOptions( - compression_type=tf.python_io.TFRecordCompressionType.ZLIB) + compression_type=TFRecordCompressionType.ZLIB) reader = tf.TFRecordReader(name="test_reader", options=options) queue = tf.FIFOQueue(99, [tf.string], shapes=()) key, value = reader.read(queue) @@ -535,7 +541,7 @@ class TFRecordWriterZlibTest(tf.test.TestCase): with self.test_session() as sess: options = tf.python_io.TFRecordOptions( - compression_type=tf.python_io.TFRecordCompressionType.ZLIB) + compression_type=TFRecordCompressionType.ZLIB) reader = tf.TFRecordReader(name="test_reader", options=options) queue = tf.FIFOQueue(1, [tf.string], shapes=()) key, value = reader.read(queue) @@ -576,6 +582,26 @@ class TFRecordWriterZlibTest(tf.test.TestCase): self.assertEqual(actual, original) + def testGzipReadWrite(self): + """Verify that files produced are gzip compatible.""" + original = [b"foo", b"bar"] + fn = self._WriteRecordsToFile(original, "gzip_read_write.tfrecord") + + # gzip compress the file and write compressed contents to file. + with open(fn, "rb") as f: + cdata = f.read() + gzfn = os.path.join(self.get_temp_dir(), "tf_record.gz") + with gzip.GzipFile(gzfn, "wb") as f: + f.write(cdata) + + actual = [] + for r in tf.python_io.tf_record_iterator( + gzfn, + options=tf.python_io.TFRecordOptions(TFRecordCompressionType.GZIP)): + actual.append(r) + self.assertEqual(actual, original) + + class TFRecordIteratorTest(tf.test.TestCase): def setUp(self): @@ -612,7 +638,7 @@ class TFRecordIteratorTest(tf.test.TestCase): [self._Record(i) for i in range(self._num_records)], "compressed_records") options = tf.python_io.TFRecordOptions( - compression_type=tf.python_io.TFRecordCompressionType.ZLIB) + compression_type=TFRecordCompressionType.ZLIB) reader = tf.python_io.tf_record_iterator(fn, options) for i in range(self._num_records): record = next(reader) @@ -631,7 +657,7 @@ class TFRecordIteratorTest(tf.test.TestCase): actual.append(r) self.assertEqual(actual, original) - def testZlibReadWriteLarge(self): + def testWriteZlibReadLarge(self): """Verify compression for large records is zlib library compatible.""" # Make it large (about 5MB) original = [_TEXT * 10240] @@ -643,6 +669,24 @@ class TFRecordIteratorTest(tf.test.TestCase): actual.append(r) self.assertEqual(actual, original) + def testWriteGzipRead(self): + original = [b"foo", b"bar"] + fn = self._WriteCompressedRecordsToFile( + original, + "write_gzip_read.tfrecord.gz", + compression_type=TFRecordCompressionType.GZIP) + + with gzip.GzipFile(fn, "rb") as f: + cdata = f.read() + zfn = os.path.join(self.get_temp_dir(), "tf_record") + with open(zfn, "wb") as f: + f.write(cdata) + + actual = [] + for r in tf.python_io.tf_record_iterator(zfn): + actual.append(r) + self.assertEqual(actual, original) + class AsyncReaderTest(tf.test.TestCase): |