aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-09-22 10:00:16 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-22 11:02:46 -0700
commitb6c8def190bc3d885dc7ae9c3c570bdf94031378 (patch)
tree20c72c0d06f204e5ca2e2a23b75aef65a180f0e1
parente4a63b578f97c9dca26fd4d3a364f90a94cb45b5 (diff)
Add support for GZIP (and other future compression types) to TFRecordReader in python.
Also clean up Record{Reader,Writer}Options creation to reuse logic and log errors. Change: 133973785
-rw-r--r--tensorflow/core/kernels/tf_record_reader_op.cc7
-rw-r--r--tensorflow/core/lib/io/record_reader.cc26
-rw-r--r--tensorflow/core/lib/io/record_reader.h3
-rw-r--r--tensorflow/core/lib/io/record_writer.cc25
-rw-r--r--tensorflow/core/lib/io/record_writer.h3
-rw-r--r--tensorflow/python/kernel_tests/reader_ops_test.py56
-rw-r--r--tensorflow/python/lib/io/py_record_reader.cc11
-rw-r--r--tensorflow/python/lib/io/py_record_writer.cc11
-rw-r--r--tensorflow/python/lib/io/tf_record.py26
-rw-r--r--tensorflow/python/ops/io_ops.py10
10 files changed, 138 insertions, 40 deletions
diff --git a/tensorflow/core/kernels/tf_record_reader_op.cc b/tensorflow/core/kernels/tf_record_reader_op.cc
index 30679069ba..e169498fd3 100644
--- a/tensorflow/core/kernels/tf_record_reader_op.cc
+++ b/tensorflow/core/kernels/tf_record_reader_op.cc
@@ -38,11 +38,8 @@ class TFRecordReader : public ReaderBase {
offset_ = 0;
TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file_));
- io::RecordReaderOptions options;
- if (compression_type_ == "ZLIB") {
- options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION;
- }
-
+ io::RecordReaderOptions options =
+ io::RecordReaderOptions::CreateRecordReaderOptions(compression_type_);
reader_.reset(new io::RecordReader(file_.get(), options));
return Status::OK();
}
diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc
index 22801859e8..8cc9d9154c 100644
--- a/tensorflow/core/lib/io/record_reader.cc
+++ b/tensorflow/core/lib/io/record_reader.cc
@@ -26,6 +26,32 @@ limitations under the License.
namespace tensorflow {
namespace io {
+RecordReaderOptions RecordReaderOptions::CreateRecordReaderOptions(
+ const string& compression_type) {
+ RecordReaderOptions options;
+ if (compression_type == "ZLIB") {
+ options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION;
+#if defined(IS_SLIM_BUILD)
+ LOG(ERROR) << "Compression is not supported but compression_type is set."
+ << " No compression will be used.";
+#else
+ options.zlib_options = io::ZlibCompressionOptions::DEFAULT();
+#endif // IS_SLIM_BUILD
+ } else if (compression_type == "GZIP") {
+ options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION;
+#if defined(IS_SLIM_BUILD)
+ LOG(ERROR) << "Compression is not supported but compression_type is set."
+ << " No compression will be used.";
+#else
+ options.zlib_options = io::ZlibCompressionOptions::GZIP();
+#endif // IS_SLIM_BUILD
+ } else if (compression_type != "") {
+ LOG(ERROR) << "Unsupported compression_type:" << compression_type
+ << ". No comprression will be used.";
+ }
+ return options;
+}
+
RecordReader::RecordReader(RandomAccessFile* file,
const RecordReaderOptions& options)
: src_(file), options_(options) {
diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h
index fb675ac98f..6c92b14963 100644
--- a/tensorflow/core/lib/io/record_reader.h
+++ b/tensorflow/core/lib/io/record_reader.h
@@ -37,6 +37,9 @@ class RecordReaderOptions {
enum CompressionType { NONE = 0, ZLIB_COMPRESSION = 1 };
CompressionType compression_type = NONE;
+ static RecordReaderOptions CreateRecordReaderOptions(
+ const string& compression_type);
+
#if !defined(IS_SLIM_BUILD)
// Options specific to zlib compression.
ZlibCompressionOptions zlib_options;
diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc
index 516332d2b7..175bfbd827 100644
--- a/tensorflow/core/lib/io/record_writer.cc
+++ b/tensorflow/core/lib/io/record_writer.cc
@@ -21,6 +21,31 @@ limitations under the License.
namespace tensorflow {
namespace io {
+RecordWriterOptions RecordWriterOptions::CreateRecordWriterOptions(
+ const string& compression_type) {
+ RecordWriterOptions options;
+ if (compression_type == "ZLIB") {
+ options.compression_type = io::RecordWriterOptions::ZLIB_COMPRESSION;
+#if defined(IS_SLIM_BUILD)
+ LOG(ERROR) << "Compression is not supported but compression_type is set."
+ << " No compression will be used.";
+#else
+ options.zlib_options = io::ZlibCompressionOptions::DEFAULT();
+#endif // IS_SLIM_BUILD
+ } else if (compression_type == "GZIP") {
+ options.compression_type = io::RecordWriterOptions::ZLIB_COMPRESSION;
+#if defined(IS_SLIM_BUILD)
+ LOG(ERROR) << "Compression is not supported but compression_type is set."
+ << " No compression will be used.";
+#else
+ options.zlib_options = io::ZlibCompressionOptions::GZIP();
+#endif // IS_SLIM_BUILD
+ } else if (compression_type != "") {
+ LOG(ERROR) << "Unsupported compression_type:" << compression_type
+ << ". No comprression will be used.";
+ }
+ return options;
+}
RecordWriter::RecordWriter(WritableFile* dest,
const RecordWriterOptions& options)
diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h
index 3d42a281de..5a2373d757 100644
--- a/tensorflow/core/lib/io/record_writer.h
+++ b/tensorflow/core/lib/io/record_writer.h
@@ -36,6 +36,9 @@ class RecordWriterOptions {
enum CompressionType { NONE = 0, ZLIB_COMPRESSION = 1 };
CompressionType compression_type = NONE;
+ static RecordWriterOptions CreateRecordWriterOptions(
+ const string& compression_type);
+
// Options specific to zlib compression.
#if !defined(IS_SLIM_BUILD)
ZlibCompressionOptions zlib_options;
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py
index dd97524357..5a0a4d71ea 100644
--- a/tensorflow/python/kernel_tests/reader_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_ops_test.py
@@ -450,6 +450,60 @@ class TFRecordReaderTest(tf.test.TestCase):
self.assertEqual(self._num_files * self._num_records, num_k)
self.assertEqual(self._num_files * self._num_records, num_v)
+ def testReadZlibFiles(self):
+ files = self._CreateFiles()
+ zlib_files = []
+ for i, fn in enumerate(files):
+ with open(fn, "rb") as f:
+ cdata = zlib.compress(f.read())
+
+ zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i)
+ with open(zfn, "wb") as f:
+ f.write(cdata)
+ zlib_files.append(zfn)
+
+ with self.test_session() as sess:
+ options = tf.python_io.TFRecordOptions(TFRecordCompressionType.ZLIB)
+ reader = tf.TFRecordReader(name="test_reader", options=options)
+ queue = tf.FIFOQueue(99, [tf.string], shapes=())
+ key, value = reader.read(queue)
+
+ queue.enqueue_many([zlib_files]).run()
+ queue.close().run()
+ for i in range(self._num_files):
+ for j in range(self._num_records):
+ k, v = sess.run([key, value])
+ self.assertTrue(
+ tf.compat.as_text(k).startswith("%s:" % zlib_files[i]))
+ self.assertAllEqual(self._Record(i, j), v)
+
+ def testReadGzipFiles(self):
+ files = self._CreateFiles()
+ gzip_files = []
+ for i, fn in enumerate(files):
+ with open(fn, "rb") as f:
+ cdata = f.read()
+
+ zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
+ with gzip.GzipFile(zfn, "wb") as f:
+ f.write(cdata)
+ gzip_files.append(zfn)
+
+ with self.test_session() as sess:
+ options = tf.python_io.TFRecordOptions(TFRecordCompressionType.GZIP)
+ reader = tf.TFRecordReader(name="test_reader", options=options)
+ queue = tf.FIFOQueue(99, [tf.string], shapes=())
+ key, value = reader.read(queue)
+
+ queue.enqueue_many([gzip_files]).run()
+ queue.close().run()
+ for i in range(self._num_files):
+ for j in range(self._num_records):
+ k, v = sess.run([key, value])
+ self.assertTrue(
+ tf.compat.as_text(k).startswith("%s:" % gzip_files[i]))
+ self.assertAllEqual(self._Record(i, j), v)
+
class TFRecordWriterZlibTest(tf.test.TestCase):
@@ -488,7 +542,7 @@ class TFRecordWriterZlibTest(tf.test.TestCase):
def _ZlibCompressFile(self, infile, name="tfrecord.z"):
# zlib compress the file and write compressed contents to file.
with open(infile, "rb") as f:
- cdata = zlib.compress(f.read(), 6)
+ cdata = zlib.compress(f.read())
zfn = os.path.join(self.get_temp_dir(), name)
with open(zfn, "wb") as f:
diff --git a/tensorflow/python/lib/io/py_record_reader.cc b/tensorflow/python/lib/io/py_record_reader.cc
index 47c0878932..d3f557506e 100644
--- a/tensorflow/python/lib/io/py_record_reader.cc
+++ b/tensorflow/python/lib/io/py_record_reader.cc
@@ -43,14 +43,9 @@ PyRecordReader* PyRecordReader::New(const string& filename, uint64 start_offset,
reader->offset_ = start_offset;
reader->file_ = file.release();
- RecordReaderOptions options;
- if (compression_type_string == "ZLIB") {
- options.compression_type = RecordReaderOptions::ZLIB_COMPRESSION;
- options.zlib_options = ZlibCompressionOptions::DEFAULT();
- } else if (compression_type_string == "GZIP") {
- options.compression_type = RecordReaderOptions::ZLIB_COMPRESSION;
- options.zlib_options = ZlibCompressionOptions::GZIP();
- }
+ RecordReaderOptions options =
+ RecordReaderOptions::CreateRecordReaderOptions(compression_type_string);
+
reader->reader_ = new RecordReader(reader->file_, options);
return reader;
}
diff --git a/tensorflow/python/lib/io/py_record_writer.cc b/tensorflow/python/lib/io/py_record_writer.cc
index d9fdda7ebf..039e59756e 100644
--- a/tensorflow/python/lib/io/py_record_writer.cc
+++ b/tensorflow/python/lib/io/py_record_writer.cc
@@ -39,14 +39,9 @@ PyRecordWriter* PyRecordWriter::New(const string& filename,
PyRecordWriter* writer = new PyRecordWriter;
writer->file_ = file.release();
- RecordWriterOptions options;
- if (compression_type_string == "ZLIB") {
- options.compression_type = RecordWriterOptions::ZLIB_COMPRESSION;
- options.zlib_options = ZlibCompressionOptions::DEFAULT();
- } else if (compression_type_string == "GZIP") {
- options.compression_type = RecordWriterOptions::ZLIB_COMPRESSION;
- options.zlib_options = ZlibCompressionOptions::GZIP();
- }
+ RecordWriterOptions options =
+ RecordWriterOptions::CreateRecordWriterOptions(compression_type_string);
+
writer->writer_ = new RecordWriter(writer->file_, options);
return writer;
}
diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py
index 96b212c8ad..c07ff5c2d3 100644
--- a/tensorflow/python/lib/io/tf_record.py
+++ b/tensorflow/python/lib/io/tf_record.py
@@ -33,17 +33,21 @@ class TFRecordCompressionType(object):
# NOTE(vrv): This will eventually be converted into a proto. to match
# the interface used by the C++ RecordWriter.
class TFRecordOptions(object):
+ """Options used for manipulating TFRecord files."""
+ compression_type_map = {
+ TFRecordCompressionType.ZLIB: "ZLIB",
+ TFRecordCompressionType.GZIP: "GZIP",
+ TFRecordCompressionType.NONE: ""
+ }
def __init__(self, compression_type):
self.compression_type = compression_type
- def get_type_as_string(self):
- if self.compression_type == TFRecordCompressionType.ZLIB:
- return "ZLIB"
- elif self.compression_type == TFRecordCompressionType.GZIP:
- return "GZIP"
- else:
+ @classmethod
+ def get_compression_type_string(cls, options):
+ if not options:
return ""
+ return cls.compression_type_map[options.compression_type]
def tf_record_iterator(path, options=None):
@@ -59,11 +63,10 @@ def tf_record_iterator(path, options=None):
Raises:
IOError: If `path` cannot be opened for reading.
"""
- compression_type_string = options.get_type_as_string() if options else ""
+ compression_type = TFRecordOptions.get_compression_type_string(options)
with errors.raise_exception_on_not_ok_status() as status:
reader = pywrap_tensorflow.PyRecordReader_New(
- compat.as_bytes(path), 0, compat.as_bytes(compression_type_string),
- status)
+ compat.as_bytes(path), 0, compat.as_bytes(compression_type), status)
if reader is None:
raise IOError("Could not open %s." % path)
@@ -94,12 +97,11 @@ class TFRecordWriter(object):
Raises:
IOError: If `path` cannot be opened for writing.
"""
- compression_type_string = options.get_type_as_string() if options else ""
+ compression_type = TFRecordOptions.get_compression_type_string(options)
with errors.raise_exception_on_not_ok_status() as status:
self._writer = pywrap_tensorflow.PyRecordWriter_New(
- compat.as_bytes(path), compat.as_bytes(compression_type_string),
- status)
+ compat.as_bytes(path), compat.as_bytes(compression_type), status)
def __enter__(self):
"""Enter a `with` block."""
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
index 2c34c7ba27..5187242ebe 100644
--- a/tensorflow/python/ops/io_ops.py
+++ b/tensorflow/python/ops/io_ops.py
@@ -464,13 +464,11 @@ class TFRecordReader(ReaderBase):
name: A name for the operation (optional).
options: A TFRecordOptions object (optional).
"""
- compression_type_string = ""
- if (options and
- options.compression_type == python_io.TFRecordCompressionType.ZLIB):
- compression_type_string = "ZLIB"
+ compression_type = python_io.TFRecordOptions.get_compression_type_string(
+ options)
- rr = gen_io_ops._tf_record_reader(name=name,
- compression_type=compression_type_string)
+ rr = gen_io_ops._tf_record_reader(
+ name=name, compression_type=compression_type)
super(TFRecordReader, self).__init__(rr)