aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/kernel_tests/reader_ops_test.py54
-rw-r--r--tensorflow/python/lib/io/py_record_reader.cc5
-rw-r--r--tensorflow/python/lib/io/py_record_writer.cc5
-rw-r--r--tensorflow/python/lib/io/tf_record.py21
4 files changed, 71 insertions, 14 deletions
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py
index bdc73e4d51..dd97524357 100644
--- a/tensorflow/python/kernel_tests/reader_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import gzip
import os
import threading
import zlib
@@ -27,6 +28,11 @@ import zlib
import six
import tensorflow as tf
+# pylint: disable=invalid-name
+TFRecordCompressionType = tf.python_io.TFRecordCompressionType
+# pylint: enable=invalid-name
+
+
# Edgar Allan Poe's 'Eldorado'
_TEXT = b"""Gaily bedight,
A gallant knight,
@@ -461,7 +467,7 @@ class TFRecordWriterZlibTest(tf.test.TestCase):
fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
filenames.append(fn)
options = tf.python_io.TFRecordOptions(
- compression_type=tf.python_io.TFRecordCompressionType.ZLIB)
+ compression_type=TFRecordCompressionType.ZLIB)
writer = tf.python_io.TFRecordWriter(fn, options=options)
for j in range(self._num_records):
writer.write(self._Record(i, j))
@@ -493,7 +499,7 @@ class TFRecordWriterZlibTest(tf.test.TestCase):
files = self._CreateFiles()
with self.test_session() as sess:
options = tf.python_io.TFRecordOptions(
- compression_type=tf.python_io.TFRecordCompressionType.ZLIB)
+ compression_type=TFRecordCompressionType.ZLIB)
reader = tf.TFRecordReader(name="test_reader", options=options)
queue = tf.FIFOQueue(99, [tf.string], shapes=())
key, value = reader.read(queue)
@@ -535,7 +541,7 @@ class TFRecordWriterZlibTest(tf.test.TestCase):
with self.test_session() as sess:
options = tf.python_io.TFRecordOptions(
- compression_type=tf.python_io.TFRecordCompressionType.ZLIB)
+ compression_type=TFRecordCompressionType.ZLIB)
reader = tf.TFRecordReader(name="test_reader", options=options)
queue = tf.FIFOQueue(1, [tf.string], shapes=())
key, value = reader.read(queue)
@@ -576,6 +582,26 @@ class TFRecordWriterZlibTest(tf.test.TestCase):
self.assertEqual(actual, original)
+ def testGzipReadWrite(self):
+ """Verify that files produced are gzip compatible."""
+ original = [b"foo", b"bar"]
+ fn = self._WriteRecordsToFile(original, "gzip_read_write.tfrecord")
+
+ # gzip compress the file and write compressed contents to file.
+ with open(fn, "rb") as f:
+ cdata = f.read()
+ gzfn = os.path.join(self.get_temp_dir(), "tf_record.gz")
+ with gzip.GzipFile(gzfn, "wb") as f:
+ f.write(cdata)
+
+ actual = []
+ for r in tf.python_io.tf_record_iterator(
+ gzfn,
+ options=tf.python_io.TFRecordOptions(TFRecordCompressionType.GZIP)):
+ actual.append(r)
+ self.assertEqual(actual, original)
+
+
class TFRecordIteratorTest(tf.test.TestCase):
def setUp(self):
@@ -612,7 +638,7 @@ class TFRecordIteratorTest(tf.test.TestCase):
[self._Record(i) for i in range(self._num_records)],
"compressed_records")
options = tf.python_io.TFRecordOptions(
- compression_type=tf.python_io.TFRecordCompressionType.ZLIB)
+ compression_type=TFRecordCompressionType.ZLIB)
reader = tf.python_io.tf_record_iterator(fn, options)
for i in range(self._num_records):
record = next(reader)
@@ -631,7 +657,7 @@ class TFRecordIteratorTest(tf.test.TestCase):
actual.append(r)
self.assertEqual(actual, original)
- def testZlibReadWriteLarge(self):
+ def testWriteZlibReadLarge(self):
"""Verify compression for large records is zlib library compatible."""
# Make it large (about 5MB)
original = [_TEXT * 10240]
@@ -643,6 +669,24 @@ class TFRecordIteratorTest(tf.test.TestCase):
actual.append(r)
self.assertEqual(actual, original)
+ def testWriteGzipRead(self):
+ original = [b"foo", b"bar"]
+ fn = self._WriteCompressedRecordsToFile(
+ original,
+ "write_gzip_read.tfrecord.gz",
+ compression_type=TFRecordCompressionType.GZIP)
+
+ with gzip.GzipFile(fn, "rb") as f:
+ cdata = f.read()
+ zfn = os.path.join(self.get_temp_dir(), "tf_record")
+ with open(zfn, "wb") as f:
+ f.write(cdata)
+
+ actual = []
+ for r in tf.python_io.tf_record_iterator(zfn):
+ actual.append(r)
+ self.assertEqual(actual, original)
+
class AsyncReaderTest(tf.test.TestCase):
diff --git a/tensorflow/python/lib/io/py_record_reader.cc b/tensorflow/python/lib/io/py_record_reader.cc
index 568f9a8e61..552cca6a4b 100644
--- a/tensorflow/python/lib/io/py_record_reader.cc
+++ b/tensorflow/python/lib/io/py_record_reader.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/record_reader.h"
+#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"
@@ -42,6 +43,10 @@ PyRecordReader* PyRecordReader::New(const string& filename, uint64 start_offset,
RecordReaderOptions options;
if (compression_type_string == "ZLIB") {
options.compression_type = RecordReaderOptions::ZLIB_COMPRESSION;
+ options.zlib_options = ZlibCompressionOptions::DEFAULT();
+ } else if (compression_type_string == "GZIP") {
+ options.compression_type = RecordReaderOptions::ZLIB_COMPRESSION;
+ options.zlib_options = ZlibCompressionOptions::GZIP();
}
reader->reader_ = new RecordReader(reader->file_, options);
return reader;
diff --git a/tensorflow/python/lib/io/py_record_writer.cc b/tensorflow/python/lib/io/py_record_writer.cc
index 73e77909fe..019f5e2fac 100644
--- a/tensorflow/python/lib/io/py_record_writer.cc
+++ b/tensorflow/python/lib/io/py_record_writer.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/record_writer.h"
+#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"
@@ -38,6 +39,10 @@ PyRecordWriter* PyRecordWriter::New(const string& filename,
RecordWriterOptions options;
if (compression_type_string == "ZLIB") {
options.compression_type = RecordWriterOptions::ZLIB_COMPRESSION;
+ options.zlib_options = ZlibCompressionOptions::DEFAULT();
+ } else if (compression_type_string == "GZIP") {
+ options.compression_type = RecordWriterOptions::ZLIB_COMPRESSION;
+ options.zlib_options = ZlibCompressionOptions::GZIP();
}
writer->writer_ = new RecordWriter(writer->file_, options);
return writer;
diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py
index 4694dee5f6..2419a1ef1a 100644
--- a/tensorflow/python/lib/io/tf_record.py
+++ b/tensorflow/python/lib/io/tf_record.py
@@ -26,6 +26,7 @@ from tensorflow.python.util import compat
class TFRecordCompressionType(object):
NONE = 0
ZLIB = 1
+ GZIP = 2
# NOTE(vrv): This will eventually be converted into a proto. to match
@@ -35,6 +36,14 @@ class TFRecordOptions(object):
def __init__(self, compression_type):
self.compression_type = compression_type
+ def get_type_as_string(self):
+ if self.compression_type == TFRecordCompressionType.ZLIB:
+ return "ZLIB"
+ elif self.compression_type == TFRecordCompressionType.GZIP:
+ return "GZIP"
+ else:
+ return ""
+
def tf_record_iterator(path, options=None):
"""An iterator that read the records from a TFRecords file.
@@ -49,11 +58,7 @@ def tf_record_iterator(path, options=None):
Raises:
IOError: If `path` cannot be opened for reading.
"""
- compression_type_string = ""
- if options:
- if options.compression_type == TFRecordCompressionType.ZLIB:
- compression_type_string = "ZLIB"
-
+ compression_type_string = options.get_type_as_string() if options else ""
reader = pywrap_tensorflow.PyRecordReader_New(
compat.as_bytes(path), 0, compat.as_bytes(compression_type_string))
@@ -74,6 +79,7 @@ class TFRecordWriter(object):
@@write
@@close
"""
+
# TODO(josh11b): Support appending?
def __init__(self, path, options=None):
"""Opens file `path` and creates a `TFRecordWriter` writing to it.
@@ -85,10 +91,7 @@ class TFRecordWriter(object):
Raises:
IOError: If `path` cannot be opened for writing.
"""
- compression_type_string = ""
- if options:
- if options.compression_type == TFRecordCompressionType.ZLIB:
- compression_type_string = "ZLIB"
+ compression_type_string = options.get_type_as_string() if options else ""
self._writer = pywrap_tensorflow.PyRecordWriter_New(
compat.as_bytes(path), compat.as_bytes(compression_type_string))