aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/lib
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-06 17:05:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 17:08:50 -0700
commite001f3ad84f58ace65df4e78941bc49e2ae61967 (patch)
treed5b885811ebdbc899784dafd9f650b141e358432 /tensorflow/python/lib
parentb096c494716b491f0be8fdc504168394d12f6c51 (diff)
Add compression options to Python's TFRecordOptions
Plumb these through to RecordWriterOptions PiperOrigin-RevId: 211894734
Diffstat (limited to 'tensorflow/python/lib')
-rw-r--r--tensorflow/python/lib/io/py_record_reader.cc2
-rw-r--r--tensorflow/python/lib/io/py_record_writer.cc6
-rw-r--r--tensorflow/python/lib/io/py_record_writer.h5
-rw-r--r--tensorflow/python/lib/io/py_record_writer.i22
-rw-r--r--tensorflow/python/lib/io/tf_record.py108
-rw-r--r--tensorflow/python/lib/io/tf_record_test.py107
6 files changed, 232 insertions, 18 deletions
diff --git a/tensorflow/python/lib/io/py_record_reader.cc b/tensorflow/python/lib/io/py_record_reader.cc
index 9500fc6a7c..07ce071845 100644
--- a/tensorflow/python/lib/io/py_record_reader.cc
+++ b/tensorflow/python/lib/io/py_record_reader.cc
@@ -30,6 +30,8 @@ namespace io {
PyRecordReader::PyRecordReader() {}
+// NOTE(sethtroisi): At this time PyRecordReader doesn't benefit from taking
+// RecordReaderOptions, if this changes the API can be updated at that time.
PyRecordReader* PyRecordReader::New(const string& filename, uint64 start_offset,
const string& compression_type_string,
TF_Status* out_status) {
diff --git a/tensorflow/python/lib/io/py_record_writer.cc b/tensorflow/python/lib/io/py_record_writer.cc
index e4e5268b0f..faf20df868 100644
--- a/tensorflow/python/lib/io/py_record_writer.cc
+++ b/tensorflow/python/lib/io/py_record_writer.cc
@@ -28,7 +28,7 @@ namespace io {
PyRecordWriter::PyRecordWriter() {}
PyRecordWriter* PyRecordWriter::New(const string& filename,
- const string& compression_type_string,
+ const io::RecordWriterOptions& options,
TF_Status* out_status) {
std::unique_ptr<WritableFile> file;
Status s = Env::Default()->NewWritableFile(filename, &file);
@@ -38,10 +38,6 @@ PyRecordWriter* PyRecordWriter::New(const string& filename,
}
PyRecordWriter* writer = new PyRecordWriter;
writer->file_ = std::move(file);
-
- RecordWriterOptions options =
- RecordWriterOptions::CreateRecordWriterOptions(compression_type_string);
-
writer->writer_.reset(new RecordWriter(writer->file_.get(), options));
return writer;
}
diff --git a/tensorflow/python/lib/io/py_record_writer.h b/tensorflow/python/lib/io/py_record_writer.h
index 61a4960ee6..9b0792c6db 100644
--- a/tensorflow/python/lib/io/py_record_writer.h
+++ b/tensorflow/python/lib/io/py_record_writer.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -36,10 +37,8 @@ class RecordWriter;
// by multiple threads.
class PyRecordWriter {
public:
- // TODO(vrv): make this take a shared proto to configure
- // the compression options.
static PyRecordWriter* New(const string& filename,
- const string& compression_type_string,
+ const io::RecordWriterOptions& compression_options,
TF_Status* out_status);
~PyRecordWriter();
diff --git a/tensorflow/python/lib/io/py_record_writer.i b/tensorflow/python/lib/io/py_record_writer.i
index 3181c9afce..b2c2bda5dd 100644
--- a/tensorflow/python/lib/io/py_record_writer.i
+++ b/tensorflow/python/lib/io/py_record_writer.i
@@ -18,6 +18,11 @@ limitations under the License.
%include "tensorflow/python/platform/base.i"
%include "tensorflow/python/lib/core/strings.i"
+// Define int8_t explicitly instead of including "stdint.i", since "stdint.h"
+// and "stdint.i" disagree on the definition of int64_t.
+typedef signed char int8;
+%{ typedef signed char int8; %}
+
%feature("except") tensorflow::io::PyRecordWriter::New {
// Let other threads run while we write
Py_BEGIN_ALLOW_THREADS
@@ -26,6 +31,7 @@ limitations under the License.
}
%newobject tensorflow::io::PyRecordWriter::New;
+%newobject tensorflow::io::RecordWriterOptions::CreateRecordWriterOptions;
%feature("except") tensorflow::io::PyRecordWriter::WriteRecord {
// Let other threads run while we write
@@ -35,6 +41,8 @@ limitations under the License.
}
%{
+#include "tensorflow/core/lib/io/record_writer.h"
+#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/python/lib/io/py_record_writer.h"
%}
@@ -48,7 +56,21 @@ limitations under the License.
%unignore tensorflow::io::PyRecordWriter::Flush;
%unignore tensorflow::io::PyRecordWriter::Close;
%unignore tensorflow::io::PyRecordWriter::New;
+%unignore tensorflow::io::ZlibCompressionOptions;
+%unignore tensorflow::io::ZlibCompressionOptions::flush_mode;
+%unignore tensorflow::io::ZlibCompressionOptions::input_buffer_size;
+%unignore tensorflow::io::ZlibCompressionOptions::output_buffer_size;
+%unignore tensorflow::io::ZlibCompressionOptions::window_bits;
+%unignore tensorflow::io::ZlibCompressionOptions::compression_level;
+%unignore tensorflow::io::ZlibCompressionOptions::compression_method;
+%unignore tensorflow::io::ZlibCompressionOptions::mem_level;
+%unignore tensorflow::io::ZlibCompressionOptions::compression_strategy;
+%unignore tensorflow::io::RecordWriterOptions;
+%unignore tensorflow::io::RecordWriterOptions::CreateRecordWriterOptions;
+%unignore tensorflow::io::RecordWriterOptions::zlib_options;
+%include "tensorflow/core/lib/io/record_writer.h"
+%include "tensorflow/core/lib/io/zlib_compression_options.h"
%include "tensorflow/python/lib/io/py_record_writer.h"
%unignoreall
diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py
index 2b3e986f6b..cce71a2bab 100644
--- a/tensorflow/python/lib/io/tf_record.py
+++ b/tensorflow/python/lib/io/tf_record.py
@@ -33,8 +33,6 @@ class TFRecordCompressionType(object):
GZIP = 2
-# NOTE(vrv): This will eventually be converted into a proto. to match
-# the interface used by the C++ RecordWriter.
@tf_export("python_io.TFRecordOptions")
class TFRecordOptions(object):
"""Options used for manipulating TFRecord files."""
@@ -44,14 +42,105 @@ class TFRecordOptions(object):
TFRecordCompressionType.NONE: ""
}
- def __init__(self, compression_type):
+ def __init__(self,
+ compression_type=None,
+ flush_mode=None,
+ input_buffer_size=None,
+ output_buffer_size=None,
+ window_bits=None,
+ compression_level=None,
+ compression_method=None,
+ mem_level=None,
+ compression_strategy=None):
+ # pylint: disable=line-too-long
+ """Creates a `TFRecordOptions` instance.
+
+ Options only effect TFRecordWriter when compression_type is not `None`.
+ Documentation, details, and defaults can be found in
+ [`zlib_compression_options.h`](https://www.tensorflow.org/code/tensorflow/core/lib/io/zlib_compression_options.h)
+ and in the [zlib manual](http://www.zlib.net/manual.html).
+ Leaving an option as `None` allows C++ to set a reasonable default.
+
+ Args:
+ compression_type: `TFRecordCompressionType` or `None`.
+ flush_mode: flush mode or `None`, Default: Z_NO_FLUSH.
+ input_buffer_size: int or `None`.
+ output_buffer_size: int or `None`.
+ window_bits: int or `None`.
+ compression_level: 0 to 9, or `None`.
+ compression_method: compression method or `None`.
+ mem_level: 1 to 9, or `None`.
+ compression_strategy: strategy or `None`. Default: Z_DEFAULT_STRATEGY.
+
+ Returns:
+ A `TFRecordOptions` object.
+
+ Raises:
+ ValueError: If compression_type is invalid.
+ """
+ # pylint: enable=line-too-long
+ # Check compression_type is valid, but for backwards compatibility don't
+ # immediately convert to a string.
+ self.get_compression_type_string(compression_type)
self.compression_type = compression_type
+ self.flush_mode = flush_mode
+ self.input_buffer_size = input_buffer_size
+ self.output_buffer_size = output_buffer_size
+ self.window_bits = window_bits
+ self.compression_level = compression_level
+ self.compression_method = compression_method
+ self.mem_level = mem_level
+ self.compression_strategy = compression_strategy
@classmethod
def get_compression_type_string(cls, options):
+ """Convert various option types to a unified string.
+
+ Args:
+ options: `TFRecordOption`, `TFRecordCompressionType`, or string.
+
+ Returns:
+ Compression type as string (e.g. `'ZLIB'`, `'GZIP'`, or `''`).
+
+ Raises:
+ ValueError: If compression_type is invalid.
+ """
if not options:
return ""
- return cls.compression_type_map[options.compression_type]
+ elif isinstance(options, TFRecordOptions):
+ return cls.get_compression_type_string(options.compression_type)
+ elif isinstance(options, TFRecordCompressionType):
+ return cls.compression_type_map[options]
+ elif options in TFRecordOptions.compression_type_map:
+ return cls.compression_type_map[options]
+ elif options in TFRecordOptions.compression_type_map.values():
+ return options
+ else:
+ raise ValueError('Not a valid compression_type: "{}"'.format(options))
+
+ def _as_record_writer_options(self):
+ """Convert to RecordWriterOptions for use with PyRecordWriter."""
+ options = pywrap_tensorflow.RecordWriterOptions_CreateRecordWriterOptions(
+ compat.as_bytes(
+ self.get_compression_type_string(self.compression_type)))
+
+ if self.flush_mode is not None:
+ options.zlib_options.flush_mode = self.flush_mode
+ if self.input_buffer_size is not None:
+ options.zlib_options.input_buffer_size = self.input_buffer_size
+ if self.output_buffer_size is not None:
+ options.zlib_options.output_buffer_size = self.output_buffer_size
+ if self.window_bits is not None:
+ options.zlib_options.window_bits = self.window_bits
+ if self.compression_level is not None:
+ options.zlib_options.compression_level = self.compression_level
+ if self.compression_method is not None:
+ options.zlib_options.compression_method = self.compression_method
+ if self.mem_level is not None:
+ options.zlib_options.mem_level = self.mem_level
+ if self.compression_strategy is not None:
+ options.zlib_options.compression_strategy = self.compression_strategy
+ return options
@tf_export("python_io.tf_record_iterator")
@@ -100,16 +189,21 @@ class TFRecordWriter(object):
Args:
path: The path to the TFRecords file.
- options: (optional) A TFRecordOptions object.
+ options: (optional) String specifying compression type,
+ `TFRecordCompressionType`, or `TFRecordOptions` object.
Raises:
IOError: If `path` cannot be opened for writing.
+ ValueError: If valid compression_type can't be determined from `options`.
"""
- compression_type = TFRecordOptions.get_compression_type_string(options)
+ if not isinstance(options, TFRecordOptions):
+ options = TFRecordOptions(compression_type=options)
with errors.raise_exception_on_not_ok_status() as status:
+ # pylint: disable=protected-access
self._writer = pywrap_tensorflow.PyRecordWriter_New(
- compat.as_bytes(path), compat.as_bytes(compression_type), status)
+ compat.as_bytes(path), options._as_record_writer_options(), status)
+ # pylint: enable=protected-access
def __enter__(self):
"""Enter a `with` block."""
diff --git a/tensorflow/python/lib/io/tf_record_test.py b/tensorflow/python/lib/io/tf_record_test.py
index b853b64ae4..def8fe23e5 100644
--- a/tensorflow/python/lib/io/tf_record_test.py
+++ b/tensorflow/python/lib/io/tf_record_test.py
@@ -20,6 +20,8 @@ from __future__ import print_function
import gzip
import os
+import random
+import string
import zlib
import six
@@ -131,9 +133,6 @@ class TFCompressionTestCase(test.TestCase):
class TFRecordWriterTest(TFCompressionTestCase):
- def setUp(self):
- super(TFRecordWriterTest, self).setUp()
-
def _AssertFilesEqual(self, a, b, equal):
for an, bn in zip(a, b):
with open(an, "rb") as af, open(bn, "rb") as bf:
@@ -142,6 +141,37 @@ class TFRecordWriterTest(TFCompressionTestCase):
else:
self.assertNotEqual(af.read(), bf.read())
+ def _CompressionSizeDelta(self, records, options_a, options_b):
+ """Validate compression with options_a and options_b and return size delta.
+
+ Compress records with options_a and options_b. Uncompress both compressed
+ files and assert that the contents match the original records. Finally
+ calculate how much smaller the file compressed with options_a was than the
+ file compressed with options_b.
+
+ Args:
+ records: The records to compress
+ options_a: First set of options to compress with, the baseline for size.
+ options_b: Second set of options to compress with.
+
+ Returns:
+ The difference in file size when using options_a vs options_b. A positive
+ value means options_a was a better compression than options_b. A negative
+ value means options_b had better compression than options_a.
+
+ """
+
+ fn_a = self._WriteRecordsToFile(records, "tfrecord_a", options=options_a)
+ test_a = list(tf_record.tf_record_iterator(fn_a, options=options_a))
+ self.assertEqual(records, test_a, options_a)
+
+ fn_b = self._WriteRecordsToFile(records, "tfrecord_b", options=options_b)
+ test_b = list(tf_record.tf_record_iterator(fn_b, options=options_b))
+ self.assertEqual(records, test_b, options_b)
+
+ # Negative number => better compression.
+ return os.path.getsize(fn_a) - os.path.getsize(fn_b)
+
def testWriteReadZLibFiles(self):
# Write uncompressed then compress manually.
options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE)
@@ -188,6 +218,76 @@ class TFRecordWriterTest(TFCompressionTestCase):
]
self._AssertFilesEqual(uncompressed_files, files, True)
+ def testNoCompressionType(self):
+ self.assertEqual(
+ "",
+ tf_record.TFRecordOptions.get_compression_type_string(
+ tf_record.TFRecordOptions()))
+
+ self.assertEqual(
+ "",
+ tf_record.TFRecordOptions.get_compression_type_string(
+ tf_record.TFRecordOptions("")))
+
+ with self.assertRaises(ValueError):
+ tf_record.TFRecordOptions(5)
+
+ with self.assertRaises(ValueError):
+ tf_record.TFRecordOptions("BZ2")
+
+ def testZlibCompressionType(self):
+ zlib_t = tf_record.TFRecordCompressionType.ZLIB
+
+ self.assertEqual(
+ "ZLIB",
+ tf_record.TFRecordOptions.get_compression_type_string(
+ tf_record.TFRecordOptions("ZLIB")))
+
+ self.assertEqual(
+ "ZLIB",
+ tf_record.TFRecordOptions.get_compression_type_string(
+ tf_record.TFRecordOptions(zlib_t)))
+
+ self.assertEqual(
+ "ZLIB",
+ tf_record.TFRecordOptions.get_compression_type_string(
+ tf_record.TFRecordOptions(tf_record.TFRecordOptions(zlib_t))))
+
+ def testCompressionOptions(self):
+ # Create record with mix of random and repeated data to test compression on.
+ rnd = random.Random(123)
+ random_record = compat.as_bytes(
+ "".join(rnd.choice(string.digits) for _ in range(10000)))
+ repeated_record = compat.as_bytes(_TEXT)
+ for _ in range(10000):
+ start_i = rnd.randint(0, len(_TEXT))
+ length = rnd.randint(10, 200)
+ repeated_record += _TEXT[start_i:start_i + length]
+ records = [random_record, repeated_record, random_record]
+
+ tests = [
+ ("compression_level", 2, -1), # Lower compression is worse.
+ ("compression_level", 6, 0), # Default compression_level is equal.
+ ("flush_mode", zlib.Z_FULL_FLUSH, 1), # A few less bytes.
+ ("flush_mode", zlib.Z_NO_FLUSH, 0), # NO_FLUSH is the default.
+ ("input_buffer_size", 4096, 0), # Increases time not size.
+ ("output_buffer_size", 4096, 0), # Increases time not size.
+ ("window_bits", 8, -1), # Smaller than default window increases size.
+ ("compression_strategy", zlib.Z_HUFFMAN_ONLY, -1), # Worse.
+ ("compression_strategy", zlib.Z_FILTERED, -1), # Worse.
+ ]
+
+ compression_type = tf_record.TFRecordCompressionType.ZLIB
+ options_a = tf_record.TFRecordOptions(compression_type)
+ for prop, value, delta_sign in tests:
+ options_b = tf_record.TFRecordOptions(
+ compression_type=compression_type, **{prop: value})
+ delta = self._CompressionSizeDelta(records, options_a, options_b)
+ self.assertTrue(
+ delta == 0 if delta_sign == 0 else delta // delta_sign > 0,
+ "Setting {} = {}, file was {} smaller didn't match sign of {}".format(
+ prop, value, delta, delta_sign))
+
class TFRecordWriterZlibTest(TFCompressionTestCase):
@@ -318,6 +418,7 @@ class TFRecordIteratorTest(TFCompressionTestCase):
for _ in tf_record.tf_record_iterator(fn_truncated):
pass
+
class TFRecordWriterCloseAndFlushTests(test.TestCase):
def setUp(self, compression_type=TFRecordCompressionType.NONE):