diff options
author | 2018-09-06 17:05:04 -0700 | |
---|---|---|
committer | 2018-09-06 17:08:50 -0700 | |
commit | e001f3ad84f58ace65df4e78941bc49e2ae61967 (patch) | |
tree | d5b885811ebdbc899784dafd9f650b141e358432 /tensorflow/python/lib | |
parent | b096c494716b491f0be8fdc504168394d12f6c51 (diff) |
Add compression options to Python's TFRecordOptions
Plumb these through to RecordWriterOptions
PiperOrigin-RevId: 211894734
Diffstat (limited to 'tensorflow/python/lib')
-rw-r--r-- | tensorflow/python/lib/io/py_record_reader.cc | 2 | ||||
-rw-r--r-- | tensorflow/python/lib/io/py_record_writer.cc | 6 | ||||
-rw-r--r-- | tensorflow/python/lib/io/py_record_writer.h | 5 | ||||
-rw-r--r-- | tensorflow/python/lib/io/py_record_writer.i | 22 | ||||
-rw-r--r-- | tensorflow/python/lib/io/tf_record.py | 108 | ||||
-rw-r--r-- | tensorflow/python/lib/io/tf_record_test.py | 107 |
6 files changed, 232 insertions, 18 deletions
diff --git a/tensorflow/python/lib/io/py_record_reader.cc b/tensorflow/python/lib/io/py_record_reader.cc index 9500fc6a7c..07ce071845 100644 --- a/tensorflow/python/lib/io/py_record_reader.cc +++ b/tensorflow/python/lib/io/py_record_reader.cc @@ -30,6 +30,8 @@ namespace io { PyRecordReader::PyRecordReader() {} +// NOTE(sethtroisi): At this time PyRecordReader doesn't benefit from taking +// RecordReaderOptions, if this changes the API can be updated at that time. PyRecordReader* PyRecordReader::New(const string& filename, uint64 start_offset, const string& compression_type_string, TF_Status* out_status) { diff --git a/tensorflow/python/lib/io/py_record_writer.cc b/tensorflow/python/lib/io/py_record_writer.cc index e4e5268b0f..faf20df868 100644 --- a/tensorflow/python/lib/io/py_record_writer.cc +++ b/tensorflow/python/lib/io/py_record_writer.cc @@ -28,7 +28,7 @@ namespace io { PyRecordWriter::PyRecordWriter() {} PyRecordWriter* PyRecordWriter::New(const string& filename, - const string& compression_type_string, + const io::RecordWriterOptions& options, TF_Status* out_status) { std::unique_ptr<WritableFile> file; Status s = Env::Default()->NewWritableFile(filename, &file); @@ -38,10 +38,6 @@ PyRecordWriter* PyRecordWriter::New(const string& filename, } PyRecordWriter* writer = new PyRecordWriter; writer->file_ = std::move(file); - - RecordWriterOptions options = - RecordWriterOptions::CreateRecordWriterOptions(compression_type_string); - writer->writer_.reset(new RecordWriter(writer->file_.get(), options)); return writer; } diff --git a/tensorflow/python/lib/io/py_record_writer.h b/tensorflow/python/lib/io/py_record_writer.h index 61a4960ee6..9b0792c6db 100644 --- a/tensorflow/python/lib/io/py_record_writer.h +++ b/tensorflow/python/lib/io/py_record_writer.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/io/record_writer.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -36,10 +37,8 @@ class RecordWriter; // by multiple threads. class PyRecordWriter { public: - // TODO(vrv): make this take a shared proto to configure - // the compression options. static PyRecordWriter* New(const string& filename, - const string& compression_type_string, + const io::RecordWriterOptions& compression_options, TF_Status* out_status); ~PyRecordWriter(); diff --git a/tensorflow/python/lib/io/py_record_writer.i b/tensorflow/python/lib/io/py_record_writer.i index 3181c9afce..b2c2bda5dd 100644 --- a/tensorflow/python/lib/io/py_record_writer.i +++ b/tensorflow/python/lib/io/py_record_writer.i @@ -18,6 +18,11 @@ limitations under the License. %include "tensorflow/python/platform/base.i" %include "tensorflow/python/lib/core/strings.i" +// Define int8_t explicitly instead of including "stdint.i", since "stdint.h" +// and "stdint.i" disagree on the definition of int64_t. +typedef signed char int8; +%{ typedef signed char int8; %} + %feature("except") tensorflow::io::PyRecordWriter::New { // Let other threads run while we write Py_BEGIN_ALLOW_THREADS @@ -26,6 +31,7 @@ limitations under the License. } %newobject tensorflow::io::PyRecordWriter::New; +%newobject tensorflow::io::RecordWriterOptions::CreateRecordWriterOptions; %feature("except") tensorflow::io::PyRecordWriter::WriteRecord { // Let other threads run while we write @@ -35,6 +41,8 @@ limitations under the License. } %{ +#include "tensorflow/core/lib/io/record_writer.h" +#include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/python/lib/io/py_record_writer.h" %} @@ -48,7 +56,21 @@ limitations under the License. %unignore tensorflow::io::PyRecordWriter::Flush; %unignore tensorflow::io::PyRecordWriter::Close; %unignore tensorflow::io::PyRecordWriter::New; +%unignore tensorflow::io::ZlibCompressionOptions; +%unignore tensorflow::io::ZlibCompressionOptions::flush_mode; +%unignore tensorflow::io::ZlibCompressionOptions::input_buffer_size; +%unignore tensorflow::io::ZlibCompressionOptions::output_buffer_size; +%unignore tensorflow::io::ZlibCompressionOptions::window_bits; +%unignore tensorflow::io::ZlibCompressionOptions::compression_level; +%unignore tensorflow::io::ZlibCompressionOptions::compression_method; +%unignore tensorflow::io::ZlibCompressionOptions::mem_level; +%unignore tensorflow::io::ZlibCompressionOptions::compression_strategy; +%unignore tensorflow::io::RecordWriterOptions; +%unignore tensorflow::io::RecordWriterOptions::CreateRecordWriterOptions; +%unignore tensorflow::io::RecordWriterOptions::zlib_options; +%include "tensorflow/core/lib/io/record_writer.h" +%include "tensorflow/core/lib/io/zlib_compression_options.h" %include "tensorflow/python/lib/io/py_record_writer.h" %unignoreall diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py index 2b3e986f6b..cce71a2bab 100644 --- a/tensorflow/python/lib/io/tf_record.py +++ b/tensorflow/python/lib/io/tf_record.py @@ -33,8 +33,6 @@ class TFRecordCompressionType(object): GZIP = 2 -# NOTE(vrv): This will eventually be converted into a proto. to match -# the interface used by the C++ RecordWriter. @tf_export("python_io.TFRecordOptions") class TFRecordOptions(object): """Options used for manipulating TFRecord files.""" @@ -44,14 +42,105 @@ class TFRecordOptions(object): TFRecordCompressionType.NONE: "" } - def __init__(self, compression_type): + def __init__(self, + compression_type=None, + flush_mode=None, + input_buffer_size=None, + output_buffer_size=None, + window_bits=None, + compression_level=None, + compression_method=None, + mem_level=None, + compression_strategy=None): + # pylint: disable=line-too-long + """Creates a `TFRecordOptions` instance. + + Options only effect TFRecordWriter when compression_type is not `None`. + Documentation, details, and defaults can be found in + [`zlib_compression_options.h`](https://www.tensorflow.org/code/tensorflow/core/lib/io/zlib_compression_options.h) + and in the [zlib manual](http://www.zlib.net/manual.html). + Leaving an option as `None` allows C++ to set a reasonable default. + + Args: + compression_type: `TFRecordCompressionType` or `None`. + flush_mode: flush mode or `None`, Default: Z_NO_FLUSH. + input_buffer_size: int or `None`. + output_buffer_size: int or `None`. + window_bits: int or `None`. + compression_level: 0 to 9, or `None`. + compression_method: compression method or `None`. + mem_level: 1 to 9, or `None`. + compression_strategy: strategy or `None`. Default: Z_DEFAULT_STRATEGY. + + Returns: + A `TFRecordOptions` object. + + Raises: + ValueError: If compression_type is invalid. + """ + # pylint: enable=line-too-long + # Check compression_type is valid, but for backwards compatibility don't + # immediately convert to a string. + self.get_compression_type_string(compression_type) self.compression_type = compression_type + self.flush_mode = flush_mode + self.input_buffer_size = input_buffer_size + self.output_buffer_size = output_buffer_size + self.window_bits = window_bits + self.compression_level = compression_level + self.compression_method = compression_method + self.mem_level = mem_level + self.compression_strategy = compression_strategy @classmethod def get_compression_type_string(cls, options): + """Convert various option types to a unified string. + + Args: + options: `TFRecordOption`, `TFRecordCompressionType`, or string. + + Returns: + Compression type as string (e.g. `'ZLIB'`, `'GZIP'`, or `''`). + + Raises: + ValueError: If compression_type is invalid. + """ if not options: return "" - return cls.compression_type_map[options.compression_type] + elif isinstance(options, TFRecordOptions): + return cls.get_compression_type_string(options.compression_type) + elif isinstance(options, TFRecordCompressionType): + return cls.compression_type_map[options] + elif options in TFRecordOptions.compression_type_map: + return cls.compression_type_map[options] + elif options in TFRecordOptions.compression_type_map.values(): + return options + else: + raise ValueError('Not a valid compression_type: "{}"'.format(options)) + + def _as_record_writer_options(self): + """Convert to RecordWriterOptions for use with PyRecordWriter.""" + options = pywrap_tensorflow.RecordWriterOptions_CreateRecordWriterOptions( + compat.as_bytes( + self.get_compression_type_string(self.compression_type))) + + if self.flush_mode is not None: + options.zlib_options.flush_mode = self.flush_mode + if self.input_buffer_size is not None: + options.zlib_options.input_buffer_size = self.input_buffer_size + if self.output_buffer_size is not None: + options.zlib_options.output_buffer_size = self.output_buffer_size + if self.window_bits is not None: + options.zlib_options.window_bits = self.window_bits + if self.compression_level is not None: + options.zlib_options.compression_level = self.compression_level + if self.compression_method is not None: + options.zlib_options.compression_method = self.compression_method + if self.mem_level is not None: + options.zlib_options.mem_level = self.mem_level + if self.compression_strategy is not None: + options.zlib_options.compression_strategy = self.compression_strategy + return options @tf_export("python_io.tf_record_iterator") @@ -100,16 +189,21 @@ class TFRecordWriter(object): Args: path: The path to the TFRecords file. - options: (optional) A TFRecordOptions object. + options: (optional) String specifying compression type, + `TFRecordCompressionType`, or `TFRecordOptions` object. Raises: IOError: If `path` cannot be opened for writing. + ValueError: If valid compression_type can't be determined from `options`. """ - compression_type = TFRecordOptions.get_compression_type_string(options) + if not isinstance(options, TFRecordOptions): + options = TFRecordOptions(compression_type=options) with errors.raise_exception_on_not_ok_status() as status: + # pylint: disable=protected-access self._writer = pywrap_tensorflow.PyRecordWriter_New( - compat.as_bytes(path), compat.as_bytes(compression_type), status) + compat.as_bytes(path), options._as_record_writer_options(), status) + # pylint: enable=protected-access def __enter__(self): """Enter a `with` block.""" diff --git a/tensorflow/python/lib/io/tf_record_test.py b/tensorflow/python/lib/io/tf_record_test.py index b853b64ae4..def8fe23e5 100644 --- a/tensorflow/python/lib/io/tf_record_test.py +++ b/tensorflow/python/lib/io/tf_record_test.py @@ -20,6 +20,8 @@ from __future__ import print_function import gzip import os +import random +import string import zlib import six @@ -131,9 +133,6 @@ class TFCompressionTestCase(test.TestCase): class TFRecordWriterTest(TFCompressionTestCase): - def setUp(self): - super(TFRecordWriterTest, self).setUp() - def _AssertFilesEqual(self, a, b, equal): for an, bn in zip(a, b): with open(an, "rb") as af, open(bn, "rb") as bf: @@ -142,6 +141,37 @@ class TFRecordWriterTest(TFCompressionTestCase): else: self.assertNotEqual(af.read(), bf.read()) + def _CompressionSizeDelta(self, records, options_a, options_b): + """Validate compression with options_a and options_b and return size delta. + + Compress records with options_a and options_b. Uncompress both compressed + files and assert that the contents match the original records. Finally + calculate how much smaller the file compressed with options_a was than the + file compressed with options_b. + + Args: + records: The records to compress + options_a: First set of options to compress with, the baseline for size. + options_b: Second set of options to compress with. + + Returns: + The difference in file size when using options_a vs options_b. A positive + value means options_a was a better compression than options_b. A negative + value means options_b had better compression than options_a. + + """ + + fn_a = self._WriteRecordsToFile(records, "tfrecord_a", options=options_a) + test_a = list(tf_record.tf_record_iterator(fn_a, options=options_a)) + self.assertEqual(records, test_a, options_a) + + fn_b = self._WriteRecordsToFile(records, "tfrecord_b", options=options_b) + test_b = list(tf_record.tf_record_iterator(fn_b, options=options_b)) + self.assertEqual(records, test_b, options_b) + + # Negative number => better compression. + return os.path.getsize(fn_a) - os.path.getsize(fn_b) + def testWriteReadZLibFiles(self): # Write uncompressed then compress manually. options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE) @@ -188,6 +218,76 @@ class TFRecordWriterTest(TFCompressionTestCase): ] self._AssertFilesEqual(uncompressed_files, files, True) + def testNoCompressionType(self): + self.assertEqual( + "", + tf_record.TFRecordOptions.get_compression_type_string( + tf_record.TFRecordOptions())) + + self.assertEqual( + "", + tf_record.TFRecordOptions.get_compression_type_string( + tf_record.TFRecordOptions(""))) + + with self.assertRaises(ValueError): + tf_record.TFRecordOptions(5) + + with self.assertRaises(ValueError): + tf_record.TFRecordOptions("BZ2") + + def testZlibCompressionType(self): + zlib_t = tf_record.TFRecordCompressionType.ZLIB + + self.assertEqual( + "ZLIB", + tf_record.TFRecordOptions.get_compression_type_string( + tf_record.TFRecordOptions("ZLIB"))) + + self.assertEqual( + "ZLIB", + tf_record.TFRecordOptions.get_compression_type_string( + tf_record.TFRecordOptions(zlib_t))) + + self.assertEqual( + "ZLIB", + tf_record.TFRecordOptions.get_compression_type_string( + tf_record.TFRecordOptions(tf_record.TFRecordOptions(zlib_t)))) + + def testCompressionOptions(self): + # Create record with mix of random and repeated data to test compression on. + rnd = random.Random(123) + random_record = compat.as_bytes( + "".join(rnd.choice(string.digits) for _ in range(10000))) + repeated_record = compat.as_bytes(_TEXT) + for _ in range(10000): + start_i = rnd.randint(0, len(_TEXT)) + length = rnd.randint(10, 200) + repeated_record += _TEXT[start_i:start_i + length] + records = [random_record, repeated_record, random_record] + + tests = [ + ("compression_level", 2, -1), # Lower compression is worse. + ("compression_level", 6, 0), # Default compression_level is equal. + ("flush_mode", zlib.Z_FULL_FLUSH, 1), # A few less bytes. + ("flush_mode", zlib.Z_NO_FLUSH, 0), # NO_FLUSH is the default. + ("input_buffer_size", 4096, 0), # Increases time not size. + ("output_buffer_size", 4096, 0), # Increases time not size. + ("window_bits", 8, -1), # Smaller than default window increases size. + ("compression_strategy", zlib.Z_HUFFMAN_ONLY, -1), # Worse. + ("compression_strategy", zlib.Z_FILTERED, -1), # Worse. + ] + + compression_type = tf_record.TFRecordCompressionType.ZLIB + options_a = tf_record.TFRecordOptions(compression_type) + for prop, value, delta_sign in tests: + options_b = tf_record.TFRecordOptions( + compression_type=compression_type, **{prop: value}) + delta = self._CompressionSizeDelta(records, options_a, options_b) + self.assertTrue( + delta == 0 if delta_sign == 0 else delta // delta_sign > 0, + "Setting {} = {}, file was {} smaller didn't match sign of {}".format( + prop, value, delta, delta_sign)) + class TFRecordWriterZlibTest(TFCompressionTestCase): @@ -318,6 +418,7 @@ class TFRecordIteratorTest(TFCompressionTestCase): for _ in tf_record.tf_record_iterator(fn_truncated): pass + class TFRecordWriterCloseAndFlushTests(test.TestCase): def setUp(self, compression_type=TFRecordCompressionType.NONE): |