aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/reader_ops_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-09-13 18:55:03 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-13 20:03:08 -0700
commit0d13cdf6a1bb6a5d73b340f77338f2c3e1139bc3 (patch)
tree7d73fc2cd79cb6056b8e1fe152d2faf6f312a63b /tensorflow/python/kernel_tests/reader_ops_test.py
parent2e9964a2b199fe3e10e3f9b6fe28342ebd621a3b (diff)
Plumb the existing support for gzip compression in TensorFlow
out to the RecordWriter/Reader, and add tests to validate its compatibility with external gzip libraries. Change: 133081474
Diffstat (limited to 'tensorflow/python/kernel_tests/reader_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/reader_ops_test.py54
1 files changed, 49 insertions, 5 deletions
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py
index bdc73e4d51..dd97524357 100644
--- a/tensorflow/python/kernel_tests/reader_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import gzip
import os
import threading
import zlib
@@ -27,6 +28,11 @@ import zlib
import six
import tensorflow as tf
+# pylint: disable=invalid-name
+TFRecordCompressionType = tf.python_io.TFRecordCompressionType
+# pylint: enable=invalid-name
+
+
# Edgar Allan Poe's 'Eldorado'
_TEXT = b"""Gaily bedight,
A gallant knight,
@@ -461,7 +467,7 @@ class TFRecordWriterZlibTest(tf.test.TestCase):
fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
filenames.append(fn)
options = tf.python_io.TFRecordOptions(
- compression_type=tf.python_io.TFRecordCompressionType.ZLIB)
+ compression_type=TFRecordCompressionType.ZLIB)
writer = tf.python_io.TFRecordWriter(fn, options=options)
for j in range(self._num_records):
writer.write(self._Record(i, j))
@@ -493,7 +499,7 @@ class TFRecordWriterZlibTest(tf.test.TestCase):
files = self._CreateFiles()
with self.test_session() as sess:
options = tf.python_io.TFRecordOptions(
- compression_type=tf.python_io.TFRecordCompressionType.ZLIB)
+ compression_type=TFRecordCompressionType.ZLIB)
reader = tf.TFRecordReader(name="test_reader", options=options)
queue = tf.FIFOQueue(99, [tf.string], shapes=())
key, value = reader.read(queue)
@@ -535,7 +541,7 @@ class TFRecordWriterZlibTest(tf.test.TestCase):
with self.test_session() as sess:
options = tf.python_io.TFRecordOptions(
- compression_type=tf.python_io.TFRecordCompressionType.ZLIB)
+ compression_type=TFRecordCompressionType.ZLIB)
reader = tf.TFRecordReader(name="test_reader", options=options)
queue = tf.FIFOQueue(1, [tf.string], shapes=())
key, value = reader.read(queue)
@@ -576,6 +582,26 @@ class TFRecordWriterZlibTest(tf.test.TestCase):
self.assertEqual(actual, original)
+ def testGzipReadWrite(self):
+ """Verify that files produced are gzip compatible."""
+ original = [b"foo", b"bar"]
+ fn = self._WriteRecordsToFile(original, "gzip_read_write.tfrecord")
+
+ # gzip compress the file and write compressed contents to file.
+ with open(fn, "rb") as f:
+ cdata = f.read()
+ gzfn = os.path.join(self.get_temp_dir(), "tf_record.gz")
+ with gzip.GzipFile(gzfn, "wb") as f:
+ f.write(cdata)
+
+ actual = []
+ for r in tf.python_io.tf_record_iterator(
+ gzfn,
+ options=tf.python_io.TFRecordOptions(TFRecordCompressionType.GZIP)):
+ actual.append(r)
+ self.assertEqual(actual, original)
+
+
class TFRecordIteratorTest(tf.test.TestCase):
def setUp(self):
@@ -612,7 +638,7 @@ class TFRecordIteratorTest(tf.test.TestCase):
[self._Record(i) for i in range(self._num_records)],
"compressed_records")
options = tf.python_io.TFRecordOptions(
- compression_type=tf.python_io.TFRecordCompressionType.ZLIB)
+ compression_type=TFRecordCompressionType.ZLIB)
reader = tf.python_io.tf_record_iterator(fn, options)
for i in range(self._num_records):
record = next(reader)
@@ -631,7 +657,7 @@ class TFRecordIteratorTest(tf.test.TestCase):
actual.append(r)
self.assertEqual(actual, original)
- def testZlibReadWriteLarge(self):
+ def testWriteZlibReadLarge(self):
"""Verify compression for large records is zlib library compatible."""
# Make it large (about 5MB)
original = [_TEXT * 10240]
@@ -643,6 +669,24 @@ class TFRecordIteratorTest(tf.test.TestCase):
actual.append(r)
self.assertEqual(actual, original)
+ def testWriteGzipRead(self):
+ original = [b"foo", b"bar"]
+ fn = self._WriteCompressedRecordsToFile(
+ original,
+ "write_gzip_read.tfrecord.gz",
+ compression_type=TFRecordCompressionType.GZIP)
+
+ with gzip.GzipFile(fn, "rb") as f:
+ cdata = f.read()
+ zfn = os.path.join(self.get_temp_dir(), "tf_record")
+ with open(zfn, "wb") as f:
+ f.write(cdata)
+
+ actual = []
+ for r in tf.python_io.tf_record_iterator(zfn):
+ actual.append(r)
+ self.assertEqual(actual, original)
+
class AsyncReaderTest(tf.test.TestCase):