diff options
-rw-r--r-- | tensorflow/python/kernel_tests/reader_ops_test.py | 54 | ||||
-rw-r--r-- | tensorflow/python/lib/io/py_record_reader.cc | 5 | ||||
-rw-r--r-- | tensorflow/python/lib/io/py_record_writer.cc | 5 | ||||
-rw-r--r-- | tensorflow/python/lib/io/tf_record.py | 21 |
4 files changed, 71 insertions, 14 deletions
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py index bdc73e4d51..dd97524357 100644 --- a/tensorflow/python/kernel_tests/reader_ops_test.py +++ b/tensorflow/python/kernel_tests/reader_ops_test.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function import collections +import gzip import os import threading import zlib @@ -27,6 +28,11 @@ import zlib import six import tensorflow as tf +# pylint: disable=invalid-name +TFRecordCompressionType = tf.python_io.TFRecordCompressionType +# pylint: enable=invalid-name + + # Edgar Allan Poe's 'Eldorado' _TEXT = b"""Gaily bedight, A gallant knight, @@ -461,7 +467,7 @@ class TFRecordWriterZlibTest(tf.test.TestCase): fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) filenames.append(fn) options = tf.python_io.TFRecordOptions( - compression_type=tf.python_io.TFRecordCompressionType.ZLIB) + compression_type=TFRecordCompressionType.ZLIB) writer = tf.python_io.TFRecordWriter(fn, options=options) for j in range(self._num_records): writer.write(self._Record(i, j)) @@ -493,7 +499,7 @@ class TFRecordWriterZlibTest(tf.test.TestCase): files = self._CreateFiles() with self.test_session() as sess: options = tf.python_io.TFRecordOptions( - compression_type=tf.python_io.TFRecordCompressionType.ZLIB) + compression_type=TFRecordCompressionType.ZLIB) reader = tf.TFRecordReader(name="test_reader", options=options) queue = tf.FIFOQueue(99, [tf.string], shapes=()) key, value = reader.read(queue) @@ -535,7 +541,7 @@ class TFRecordWriterZlibTest(tf.test.TestCase): with self.test_session() as sess: options = tf.python_io.TFRecordOptions( - compression_type=tf.python_io.TFRecordCompressionType.ZLIB) + compression_type=TFRecordCompressionType.ZLIB) reader = tf.TFRecordReader(name="test_reader", options=options) queue = tf.FIFOQueue(1, [tf.string], shapes=()) key, value = reader.read(queue) @@ -576,6 +582,26 @@ class TFRecordWriterZlibTest(tf.test.TestCase): self.assertEqual(actual, original) + def testGzipReadWrite(self): + """Verify that files produced are gzip compatible.""" + original = [b"foo", b"bar"] + fn = self._WriteRecordsToFile(original, "gzip_read_write.tfrecord") + + # gzip compress the file and write compressed contents to file. + with open(fn, "rb") as f: + cdata = f.read() + gzfn = os.path.join(self.get_temp_dir(), "tf_record.gz") + with gzip.GzipFile(gzfn, "wb") as f: + f.write(cdata) + + actual = [] + for r in tf.python_io.tf_record_iterator( + gzfn, + options=tf.python_io.TFRecordOptions(TFRecordCompressionType.GZIP)): + actual.append(r) + self.assertEqual(actual, original) + + class TFRecordIteratorTest(tf.test.TestCase): def setUp(self): @@ -612,7 +638,7 @@ class TFRecordIteratorTest(tf.test.TestCase): [self._Record(i) for i in range(self._num_records)], "compressed_records") options = tf.python_io.TFRecordOptions( - compression_type=tf.python_io.TFRecordCompressionType.ZLIB) + compression_type=TFRecordCompressionType.ZLIB) reader = tf.python_io.tf_record_iterator(fn, options) for i in range(self._num_records): record = next(reader) @@ -631,7 +657,7 @@ class TFRecordIteratorTest(tf.test.TestCase): actual.append(r) self.assertEqual(actual, original) - def testZlibReadWriteLarge(self): + def testWriteZlibReadLarge(self): """Verify compression for large records is zlib library compatible.""" # Make it large (about 5MB) original = [_TEXT * 10240] @@ -643,6 +669,24 @@ class TFRecordIteratorTest(tf.test.TestCase): actual.append(r) self.assertEqual(actual, original) + def testWriteGzipRead(self): + original = [b"foo", b"bar"] + fn = self._WriteCompressedRecordsToFile( + original, + "write_gzip_read.tfrecord.gz", + compression_type=TFRecordCompressionType.GZIP) + + with gzip.GzipFile(fn, "rb") as f: + cdata = f.read() + zfn = os.path.join(self.get_temp_dir(), "tf_record") + with open(zfn, "wb") as f: + f.write(cdata) + + actual = [] + for r in tf.python_io.tf_record_iterator(zfn): + actual.append(r) + self.assertEqual(actual, original) + class AsyncReaderTest(tf.test.TestCase): diff --git a/tensorflow/python/lib/io/py_record_reader.cc b/tensorflow/python/lib/io/py_record_reader.cc index 568f9a8e61..552cca6a4b 100644 --- a/tensorflow/python/lib/io/py_record_reader.cc +++ b/tensorflow/python/lib/io/py_record_reader.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/record_reader.h" +#include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/types.h" @@ -42,6 +43,10 @@ PyRecordReader* PyRecordReader::New(const string& filename, uint64 start_offset, RecordReaderOptions options; if (compression_type_string == "ZLIB") { options.compression_type = RecordReaderOptions::ZLIB_COMPRESSION; + options.zlib_options = ZlibCompressionOptions::DEFAULT(); + } else if (compression_type_string == "GZIP") { + options.compression_type = RecordReaderOptions::ZLIB_COMPRESSION; + options.zlib_options = ZlibCompressionOptions::GZIP(); } reader->reader_ = new RecordReader(reader->file_, options); return reader; diff --git a/tensorflow/python/lib/io/py_record_writer.cc b/tensorflow/python/lib/io/py_record_writer.cc index 73e77909fe..019f5e2fac 100644 --- a/tensorflow/python/lib/io/py_record_writer.cc +++ b/tensorflow/python/lib/io/py_record_writer.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/record_writer.h" +#include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/types.h" @@ -38,6 +39,10 @@ PyRecordWriter* PyRecordWriter::New(const string& filename, RecordWriterOptions options; if (compression_type_string == "ZLIB") { options.compression_type = RecordWriterOptions::ZLIB_COMPRESSION; + options.zlib_options = ZlibCompressionOptions::DEFAULT(); + } else if (compression_type_string == "GZIP") { + options.compression_type = RecordWriterOptions::ZLIB_COMPRESSION; + options.zlib_options = ZlibCompressionOptions::GZIP(); } writer->writer_ = new RecordWriter(writer->file_, options); return writer; diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py index 4694dee5f6..2419a1ef1a 100644 --- a/tensorflow/python/lib/io/tf_record.py +++ b/tensorflow/python/lib/io/tf_record.py @@ -26,6 +26,7 @@ from tensorflow.python.util import compat class TFRecordCompressionType(object): NONE = 0 ZLIB = 1 + GZIP = 2 # NOTE(vrv): This will eventually be converted into a proto. to match @@ -35,6 +36,14 @@ class TFRecordOptions(object): def __init__(self, compression_type): self.compression_type = compression_type + def get_type_as_string(self): + if self.compression_type == TFRecordCompressionType.ZLIB: + return "ZLIB" + elif self.compression_type == TFRecordCompressionType.GZIP: + return "GZIP" + else: + return "" + def tf_record_iterator(path, options=None): """An iterator that read the records from a TFRecords file. @@ -49,11 +58,7 @@ def tf_record_iterator(path, options=None): Raises: IOError: If `path` cannot be opened for reading. """ - compression_type_string = "" - if options: - if options.compression_type == TFRecordCompressionType.ZLIB: - compression_type_string = "ZLIB" - + compression_type_string = options.get_type_as_string() if options else "" reader = pywrap_tensorflow.PyRecordReader_New( compat.as_bytes(path), 0, compat.as_bytes(compression_type_string)) @@ -74,6 +79,7 @@ class TFRecordWriter(object): @@write @@close """ + # TODO(josh11b): Support appending? def __init__(self, path, options=None): """Opens file `path` and creates a `TFRecordWriter` writing to it. @@ -85,10 +91,7 @@ class TFRecordWriter(object): Raises: IOError: If `path` cannot be opened for writing. """ - compression_type_string = "" - if options: - if options.compression_type == TFRecordCompressionType.ZLIB: - compression_type_string = "ZLIB" + compression_type_string = options.get_type_as_string() if options else "" self._writer = pywrap_tensorflow.PyRecordWriter_New( compat.as_bytes(path), compat.as_bytes(compression_type_string)) |