From b6c8def190bc3d885dc7ae9c3c570bdf94031378 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 22 Sep 2016 10:00:16 -0800 Subject: Add support for GZIP (and other future compression types) to TFRecordReader in python. Also clean up Record{Reader,Writer}Options creation to reuse logic and log errors. Change: 133973785 --- tensorflow/core/kernels/tf_record_reader_op.cc | 7 +-- tensorflow/core/lib/io/record_reader.cc | 26 +++++++++++ tensorflow/core/lib/io/record_reader.h | 3 ++ tensorflow/core/lib/io/record_writer.cc | 25 ++++++++++ tensorflow/core/lib/io/record_writer.h | 3 ++ tensorflow/python/kernel_tests/reader_ops_test.py | 56 ++++++++++++++++++++++- tensorflow/python/lib/io/py_record_reader.cc | 11 ++--- tensorflow/python/lib/io/py_record_writer.cc | 11 ++--- tensorflow/python/lib/io/tf_record.py | 26 ++++++----- tensorflow/python/ops/io_ops.py | 10 ++-- 10 files changed, 138 insertions(+), 40 deletions(-) diff --git a/tensorflow/core/kernels/tf_record_reader_op.cc b/tensorflow/core/kernels/tf_record_reader_op.cc index 30679069ba..e169498fd3 100644 --- a/tensorflow/core/kernels/tf_record_reader_op.cc +++ b/tensorflow/core/kernels/tf_record_reader_op.cc @@ -38,11 +38,8 @@ class TFRecordReader : public ReaderBase { offset_ = 0; TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file_)); - io::RecordReaderOptions options; - if (compression_type_ == "ZLIB") { - options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION; - } - + io::RecordReaderOptions options = + io::RecordReaderOptions::CreateRecordReaderOptions(compression_type_); reader_.reset(new io::RecordReader(file_.get(), options)); return Status::OK(); } diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc index 22801859e8..8cc9d9154c 100644 --- a/tensorflow/core/lib/io/record_reader.cc +++ b/tensorflow/core/lib/io/record_reader.cc @@ -26,6 +26,32 @@ limitations under the License. namespace tensorflow { namespace io { +RecordReaderOptions RecordReaderOptions::CreateRecordReaderOptions( + const string& compression_type) { + RecordReaderOptions options; + if (compression_type == "ZLIB") { + options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION; +#if defined(IS_SLIM_BUILD) + LOG(ERROR) << "Compression is not supported but compression_type is set." + << " No compression will be used."; +#else + options.zlib_options = io::ZlibCompressionOptions::DEFAULT(); +#endif // IS_SLIM_BUILD + } else if (compression_type == "GZIP") { + options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION; +#if defined(IS_SLIM_BUILD) + LOG(ERROR) << "Compression is not supported but compression_type is set." + << " No compression will be used."; +#else + options.zlib_options = io::ZlibCompressionOptions::GZIP(); +#endif // IS_SLIM_BUILD + } else if (compression_type != "") { + LOG(ERROR) << "Unsupported compression_type:" << compression_type + << ". No comprression will be used."; + } + return options; +} + RecordReader::RecordReader(RandomAccessFile* file, const RecordReaderOptions& options) : src_(file), options_(options) { diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h index fb675ac98f..6c92b14963 100644 --- a/tensorflow/core/lib/io/record_reader.h +++ b/tensorflow/core/lib/io/record_reader.h @@ -37,6 +37,9 @@ class RecordReaderOptions { enum CompressionType { NONE = 0, ZLIB_COMPRESSION = 1 }; CompressionType compression_type = NONE; + static RecordReaderOptions CreateRecordReaderOptions( + const string& compression_type); + #if !defined(IS_SLIM_BUILD) // Options specific to zlib compression. ZlibCompressionOptions zlib_options; diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc index 516332d2b7..175bfbd827 100644 --- a/tensorflow/core/lib/io/record_writer.cc +++ b/tensorflow/core/lib/io/record_writer.cc @@ -21,6 +21,31 @@ limitations under the License. namespace tensorflow { namespace io { +RecordWriterOptions RecordWriterOptions::CreateRecordWriterOptions( + const string& compression_type) { + RecordWriterOptions options; + if (compression_type == "ZLIB") { + options.compression_type = io::RecordWriterOptions::ZLIB_COMPRESSION; +#if defined(IS_SLIM_BUILD) + LOG(ERROR) << "Compression is not supported but compression_type is set." + << " No compression will be used."; +#else + options.zlib_options = io::ZlibCompressionOptions::DEFAULT(); +#endif // IS_SLIM_BUILD + } else if (compression_type == "GZIP") { + options.compression_type = io::RecordWriterOptions::ZLIB_COMPRESSION; +#if defined(IS_SLIM_BUILD) + LOG(ERROR) << "Compression is not supported but compression_type is set." + << " No compression will be used."; +#else + options.zlib_options = io::ZlibCompressionOptions::GZIP(); +#endif // IS_SLIM_BUILD + } else if (compression_type != "") { + LOG(ERROR) << "Unsupported compression_type:" << compression_type + << ". No comprression will be used."; + } + return options; +} RecordWriter::RecordWriter(WritableFile* dest, const RecordWriterOptions& options) diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h index 3d42a281de..5a2373d757 100644 --- a/tensorflow/core/lib/io/record_writer.h +++ b/tensorflow/core/lib/io/record_writer.h @@ -36,6 +36,9 @@ class RecordWriterOptions { enum CompressionType { NONE = 0, ZLIB_COMPRESSION = 1 }; CompressionType compression_type = NONE; + static RecordWriterOptions CreateRecordWriterOptions( + const string& compression_type); + // Options specific to zlib compression. #if !defined(IS_SLIM_BUILD) ZlibCompressionOptions zlib_options; diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py index dd97524357..5a0a4d71ea 100644 --- a/tensorflow/python/kernel_tests/reader_ops_test.py +++ b/tensorflow/python/kernel_tests/reader_ops_test.py @@ -450,6 +450,60 @@ class TFRecordReaderTest(tf.test.TestCase): self.assertEqual(self._num_files * self._num_records, num_k) self.assertEqual(self._num_files * self._num_records, num_v) + def testReadZlibFiles(self): + files = self._CreateFiles() + zlib_files = [] + for i, fn in enumerate(files): + with open(fn, "rb") as f: + cdata = zlib.compress(f.read()) + + zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i) + with open(zfn, "wb") as f: + f.write(cdata) + zlib_files.append(zfn) + + with self.test_session() as sess: + options = tf.python_io.TFRecordOptions(TFRecordCompressionType.ZLIB) + reader = tf.TFRecordReader(name="test_reader", options=options) + queue = tf.FIFOQueue(99, [tf.string], shapes=()) + key, value = reader.read(queue) + + queue.enqueue_many([zlib_files]).run() + queue.close().run() + for i in range(self._num_files): + for j in range(self._num_records): + k, v = sess.run([key, value]) + self.assertTrue( + tf.compat.as_text(k).startswith("%s:" % zlib_files[i])) + self.assertAllEqual(self._Record(i, j), v) + + def testReadGzipFiles(self): + files = self._CreateFiles() + gzip_files = [] + for i, fn in enumerate(files): + with open(fn, "rb") as f: + cdata = f.read() + + zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i) + with gzip.GzipFile(zfn, "wb") as f: + f.write(cdata) + gzip_files.append(zfn) + + with self.test_session() as sess: + options = tf.python_io.TFRecordOptions(TFRecordCompressionType.GZIP) + reader = tf.TFRecordReader(name="test_reader", options=options) + queue = tf.FIFOQueue(99, [tf.string], shapes=()) + key, value = reader.read(queue) + + queue.enqueue_many([gzip_files]).run() + queue.close().run() + for i in range(self._num_files): + for j in range(self._num_records): + k, v = sess.run([key, value]) + self.assertTrue( + tf.compat.as_text(k).startswith("%s:" % gzip_files[i])) + self.assertAllEqual(self._Record(i, j), v) + class TFRecordWriterZlibTest(tf.test.TestCase): @@ -488,7 +542,7 @@ class TFRecordWriterZlibTest(tf.test.TestCase): def _ZlibCompressFile(self, infile, name="tfrecord.z"): # zlib compress the file and write compressed contents to file. with open(infile, "rb") as f: - cdata = zlib.compress(f.read(), 6) + cdata = zlib.compress(f.read()) zfn = os.path.join(self.get_temp_dir(), name) with open(zfn, "wb") as f: diff --git a/tensorflow/python/lib/io/py_record_reader.cc b/tensorflow/python/lib/io/py_record_reader.cc index 47c0878932..d3f557506e 100644 --- a/tensorflow/python/lib/io/py_record_reader.cc +++ b/tensorflow/python/lib/io/py_record_reader.cc @@ -43,14 +43,9 @@ PyRecordReader* PyRecordReader::New(const string& filename, uint64 start_offset, reader->offset_ = start_offset; reader->file_ = file.release(); - RecordReaderOptions options; - if (compression_type_string == "ZLIB") { - options.compression_type = RecordReaderOptions::ZLIB_COMPRESSION; - options.zlib_options = ZlibCompressionOptions::DEFAULT(); - } else if (compression_type_string == "GZIP") { - options.compression_type = RecordReaderOptions::ZLIB_COMPRESSION; - options.zlib_options = ZlibCompressionOptions::GZIP(); - } + RecordReaderOptions options = + RecordReaderOptions::CreateRecordReaderOptions(compression_type_string); + reader->reader_ = new RecordReader(reader->file_, options); return reader; } diff --git a/tensorflow/python/lib/io/py_record_writer.cc b/tensorflow/python/lib/io/py_record_writer.cc index d9fdda7ebf..039e59756e 100644 --- a/tensorflow/python/lib/io/py_record_writer.cc +++ b/tensorflow/python/lib/io/py_record_writer.cc @@ -39,14 +39,9 @@ PyRecordWriter* PyRecordWriter::New(const string& filename, PyRecordWriter* writer = new PyRecordWriter; writer->file_ = file.release(); - RecordWriterOptions options; - if (compression_type_string == "ZLIB") { - options.compression_type = RecordWriterOptions::ZLIB_COMPRESSION; - options.zlib_options = ZlibCompressionOptions::DEFAULT(); - } else if (compression_type_string == "GZIP") { - options.compression_type = RecordWriterOptions::ZLIB_COMPRESSION; - options.zlib_options = ZlibCompressionOptions::GZIP(); - } + RecordWriterOptions options = + RecordWriterOptions::CreateRecordWriterOptions(compression_type_string); + writer->writer_ = new RecordWriter(writer->file_, options); return writer; } diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py index 96b212c8ad..c07ff5c2d3 100644 --- a/tensorflow/python/lib/io/tf_record.py +++ b/tensorflow/python/lib/io/tf_record.py @@ -33,17 +33,21 @@ class TFRecordCompressionType(object): # NOTE(vrv): This will eventually be converted into a proto. to match # the interface used by the C++ RecordWriter. class TFRecordOptions(object): + """Options used for manipulating TFRecord files.""" + compression_type_map = { + TFRecordCompressionType.ZLIB: "ZLIB", + TFRecordCompressionType.GZIP: "GZIP", + TFRecordCompressionType.NONE: "" + } def __init__(self, compression_type): self.compression_type = compression_type - def get_type_as_string(self): - if self.compression_type == TFRecordCompressionType.ZLIB: - return "ZLIB" - elif self.compression_type == TFRecordCompressionType.GZIP: - return "GZIP" - else: + @classmethod + def get_compression_type_string(cls, options): + if not options: return "" + return cls.compression_type_map[options.compression_type] def tf_record_iterator(path, options=None): @@ -59,11 +63,10 @@ def tf_record_iterator(path, options=None): Raises: IOError: If `path` cannot be opened for reading. """ - compression_type_string = options.get_type_as_string() if options else "" + compression_type = TFRecordOptions.get_compression_type_string(options) with errors.raise_exception_on_not_ok_status() as status: reader = pywrap_tensorflow.PyRecordReader_New( - compat.as_bytes(path), 0, compat.as_bytes(compression_type_string), - status) + compat.as_bytes(path), 0, compat.as_bytes(compression_type), status) if reader is None: raise IOError("Could not open %s." % path) @@ -94,12 +97,11 @@ class TFRecordWriter(object): Raises: IOError: If `path` cannot be opened for writing. """ - compression_type_string = options.get_type_as_string() if options else "" + compression_type = TFRecordOptions.get_compression_type_string(options) with errors.raise_exception_on_not_ok_status() as status: self._writer = pywrap_tensorflow.PyRecordWriter_New( - compat.as_bytes(path), compat.as_bytes(compression_type_string), - status) + compat.as_bytes(path), compat.as_bytes(compression_type), status) def __enter__(self): """Enter a `with` block.""" diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py index 2c34c7ba27..5187242ebe 100644 --- a/tensorflow/python/ops/io_ops.py +++ b/tensorflow/python/ops/io_ops.py @@ -464,13 +464,11 @@ class TFRecordReader(ReaderBase): name: A name for the operation (optional). options: A TFRecordOptions object (optional). """ - compression_type_string = "" - if (options and - options.compression_type == python_io.TFRecordCompressionType.ZLIB): - compression_type_string = "ZLIB" + compression_type = python_io.TFRecordOptions.get_compression_type_string( + options) - rr = gen_io_ops._tf_record_reader(name=name, - compression_type=compression_type_string) + rr = gen_io_ops._tf_record_reader( + name=name, compression_type=compression_type) super(TFRecordReader, self).__init__(rr) -- cgit v1.2.3