aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-05 11:58:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-05 12:02:31 -0700
commit62a70dd873bc8488b10df5ad55254119173a5d0c (patch)
tree6841cb19c6e0e894a847b72bb612f8ea4b148a41
parentb1fd2ef4d02719cd929fa574796b2c080a21a9ee (diff)
Extend and refactor reader_ops_test
PiperOrigin-RevId: 199335030
-rw-r--r--tensorflow/python/kernel_tests/reader_ops_test.py352
1 files changed, 163 insertions, 189 deletions
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py
index 82a27eebee..7be473a5e7 100644
--- a/tensorflow/python/kernel_tests/reader_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_ops_test.py
@@ -77,6 +77,69 @@ _TEXT = b"""Gaily bedight,
"""
+class TFCompressionTestCase(test.TestCase):
+
+ def setUp(self):
+ super(TFCompressionTestCase, self).setUp()
+ self._num_files = 2
+ self._num_records = 7
+
+ def _Record(self, f, r):
+ return compat.as_bytes("Record %d of file %d" % (r, f))
+
+ def _CreateFiles(self, options=None, prefix=""):
+ filenames = []
+ for i in range(self._num_files):
+ name = prefix + "tfrecord.%d.txt" % i
+ records = [self._Record(i, j) for j in range(self._num_records)]
+ fn = self._WriteRecordsToFile(records, name, options)
+ filenames.append(fn)
+ return filenames
+
+ def _WriteRecordsToFile(self, records, name="tfrecord", options=None):
+ fn = os.path.join(self.get_temp_dir(), name)
+ with tf_record.TFRecordWriter(fn, options=options) as writer:
+ for r in records:
+ writer.write(r)
+ return fn
+
+ def _ZlibCompressFile(self, infile, name="tfrecord.z"):
+ # zlib compress the file and write compressed contents to file.
+ with open(infile, "rb") as f:
+ cdata = zlib.compress(f.read())
+
+ zfn = os.path.join(self.get_temp_dir(), name)
+ with open(zfn, "wb") as f:
+ f.write(cdata)
+ return zfn
+
+ def _GzipCompressFile(self, infile, name="tfrecord.gz"):
+ # gzip compress the file and write compressed contents to file.
+ with open(infile, "rb") as f:
+ cdata = f.read()
+
+ gzfn = os.path.join(self.get_temp_dir(), name)
+ with gzip.GzipFile(gzfn, "wb") as f:
+ f.write(cdata)
+ return gzfn
+
+ def _ZlibDecompressFile(self, infile, name="tfrecord"):
+ with open(infile, "rb") as f:
+ cdata = zlib.decompress(f.read())
+ fn = os.path.join(self.get_temp_dir(), name)
+ with open(fn, "wb") as f:
+ f.write(cdata)
+ return fn
+
+ def _GzipDecompressFile(self, infile, name="tfrecord"):
+ with gzip.GzipFile(infile, "rb") as f:
+ cdata = f.read()
+ fn = os.path.join(self.get_temp_dir(), name)
+ with open(fn, "wb") as f:
+ f.write(cdata)
+ return fn
+
+
class IdentityReaderTest(test.TestCase):
def _ExpectRead(self, sess, key, value, expected):
@@ -348,7 +411,7 @@ class TextLineReaderTest(test.TestCase):
k, v = sess.run([key, value])
-class FixedLengthRecordReaderTest(test.TestCase):
+class FixedLengthRecordReaderTest(TFCompressionTestCase):
def setUp(self):
super(FixedLengthRecordReaderTest, self).setUp()
@@ -407,40 +470,18 @@ class FixedLengthRecordReaderTest(test.TestCase):
# gap_bytes=hop_bytes-record_bytes
def _CreateGzipFiles(self, num_records, gap_bytes):
- filenames = []
- for i in range(self._num_files):
- fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
- filenames.append(fn)
- with gzip.GzipFile(fn, "wb") as f:
- f.write(b"H" * self._header_bytes)
- if num_records > 0:
- f.write(self._Record(i, 0))
- for j in range(1, num_records):
- if gap_bytes > 0:
- f.write(b"G" * gap_bytes)
- f.write(self._Record(i, j))
- f.write(b"F" * self._footer_bytes)
+ filenames = self._CreateFiles(num_records, gap_bytes)
+ for fn in filenames:
+ # Compress in place.
+ self._GzipCompressFile(fn, fn)
return filenames
# gap_bytes=hop_bytes-record_bytes
def _CreateZlibFiles(self, num_records, gap_bytes):
- filenames = []
- for i in range(self._num_files):
- fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
- filenames.append(fn)
- with open(fn + ".tmp", "wb") as f:
- f.write(b"H" * self._header_bytes)
- if num_records > 0:
- f.write(self._Record(i, 0))
- for j in range(1, num_records):
- if gap_bytes > 0:
- f.write(b"G" * gap_bytes)
- f.write(self._Record(i, j))
- f.write(b"F" * self._footer_bytes)
- with open(fn + ".tmp", "rb") as f:
- cdata = zlib.compress(f.read())
- with open(fn, "wb") as zf:
- zf.write(cdata)
+ filenames = self._CreateFiles(num_records, gap_bytes)
+ for fn in filenames:
+ # Compress in place.
+ self._ZlibCompressFile(fn, fn)
return filenames
def _CreateGzipOverlappedRecordFiles(self, num_overlapped_records):
@@ -477,10 +518,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
])
f.write(compat.as_bytes(all_records_str))
f.write(b"F" * self._footer_bytes)
- with open(fn + ".tmp", "rb") as f:
- cdata = zlib.compress(f.read())
- with open(fn, "wb") as zf:
- zf.write(cdata)
+ self._ZlibCompressFile(fn + ".tmp", fn)
return filenames
# gap_bytes=hop_bytes-record_bytes
@@ -529,7 +567,6 @@ class FixedLengthRecordReaderTest(test.TestCase):
for i in range(self._num_files):
for j in range(num_overlapped_records):
k, v = sess.run([key, value])
- print(v)
self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k))
self.assertAllEqual(self._OverlappedRecord(i, j), v)
@@ -579,25 +616,10 @@ class FixedLengthRecordReaderTest(test.TestCase):
files, num_overlapped_records, encoding="ZLIB")
-class TFRecordReaderTest(test.TestCase):
+class TFRecordReaderTest(TFCompressionTestCase):
def setUp(self):
super(TFRecordReaderTest, self).setUp()
- self._num_files = 2
- self._num_records = 7
-
- def _Record(self, f, r):
- return compat.as_bytes("Record %d of file %d" % (r, f))
-
- def _CreateFiles(self):
- filenames = []
- for i in range(self._num_files):
- fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
- filenames.append(fn)
- writer = tf_record.TFRecordWriter(fn)
- for j in range(self._num_records):
- writer.write(self._Record(i, j))
- return filenames
def testOneEpoch(self):
files = self._CreateFiles()
@@ -647,107 +669,106 @@ class TFRecordReaderTest(test.TestCase):
self.assertEqual(self._num_files * self._num_records, num_v)
def testReadZlibFiles(self):
- files = self._CreateFiles()
- zlib_files = []
- for i, fn in enumerate(files):
- with open(fn, "rb") as f:
- cdata = zlib.compress(f.read())
-
- zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i)
- with open(zfn, "wb") as f:
- f.write(cdata)
- zlib_files.append(zfn)
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
+ files = self._CreateFiles(options)
with self.test_session() as sess:
- options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
reader = io_ops.TFRecordReader(name="test_reader", options=options)
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
key, value = reader.read(queue)
- queue.enqueue_many([zlib_files]).run()
+ queue.enqueue_many([files]).run()
queue.close().run()
for i in range(self._num_files):
for j in range(self._num_records):
k, v = sess.run([key, value])
- self.assertTrue(compat.as_text(k).startswith("%s:" % zlib_files[i]))
+ self.assertTrue(compat.as_text(k).startswith("%s:" % files[i]))
self.assertAllEqual(self._Record(i, j), v)
def testReadGzipFiles(self):
- files = self._CreateFiles()
- gzip_files = []
- for i, fn in enumerate(files):
- with open(fn, "rb") as f:
- cdata = f.read()
-
- zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
- with gzip.GzipFile(zfn, "wb") as f:
- f.write(cdata)
- gzip_files.append(zfn)
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)
+ files = self._CreateFiles(options)
with self.test_session() as sess:
- options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)
reader = io_ops.TFRecordReader(name="test_reader", options=options)
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
key, value = reader.read(queue)
- queue.enqueue_many([gzip_files]).run()
+ queue.enqueue_many([files]).run()
queue.close().run()
for i in range(self._num_files):
for j in range(self._num_records):
k, v = sess.run([key, value])
- self.assertTrue(compat.as_text(k).startswith("%s:" % gzip_files[i]))
+ self.assertTrue(compat.as_text(k).startswith("%s:" % files[i]))
self.assertAllEqual(self._Record(i, j), v)
-class TFRecordWriterZlibTest(test.TestCase):
+class TFRecordWriterTest(TFCompressionTestCase):
def setUp(self):
- super(TFRecordWriterZlibTest, self).setUp()
- self._num_files = 2
- self._num_records = 7
+ super(TFRecordWriterTest, self).setUp()
+
+ def _AssertFilesEqual(self, a, b, equal):
+ for an, bn in zip(a, b):
+ with open(an, "rb") as af, open(bn, "rb") as bf:
+ if equal:
+ self.assertEqual(af.read(), bf.read())
+ else:
+ self.assertNotEqual(af.read(), bf.read())
+
+ def testWriteReadZLibFiles(self):
+ # Write uncompressed then compress manually.
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE)
+ files = self._CreateFiles(options, prefix="uncompressed")
+ zlib_files = [
+ self._ZlibCompressFile(fn, "tfrecord_%s.z" % i)
+ for i, fn in enumerate(files)
+ ]
+ self._AssertFilesEqual(files, zlib_files, False)
- def _Record(self, f, r):
- return compat.as_bytes("Record %d of file %d" % (r, f))
+ # Now write compressed and verify same.
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
+ compressed_files = self._CreateFiles(options, prefix="compressed")
+ self._AssertFilesEqual(compressed_files, zlib_files, True)
- def _CreateFiles(self):
- filenames = []
- for i in range(self._num_files):
- fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
- filenames.append(fn)
- options = tf_record.TFRecordOptions(
- compression_type=TFRecordCompressionType.ZLIB)
- writer = tf_record.TFRecordWriter(fn, options=options)
- for j in range(self._num_records):
- writer.write(self._Record(i, j))
- writer.close()
- del writer
+ # Decompress the compressed files and verify they match the originals.
+ uncompressed_files = [
+ self._ZlibDecompressFile(fn, "tfrecord_%s.z" % i)
+ for i, fn in enumerate(compressed_files)
+ ]
+ self._AssertFilesEqual(uncompressed_files, files, True)
+
+ def testWriteReadGzipFiles(self):
+ # Write uncompressed then compress manually.
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE)
+ files = self._CreateFiles(options, prefix="uncompressed")
+ gzip_files = [
+ self._GzipCompressFile(fn, "tfrecord_%s.gz" % i)
+ for i, fn in enumerate(files)
+ ]
+ self._AssertFilesEqual(files, gzip_files, False)
- return filenames
+ # Now write compressed and verify same.
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)
+ compressed_files = self._CreateFiles(options, prefix="compressed")
- def _WriteRecordsToFile(self, records, name="tf_record"):
- fn = os.path.join(self.get_temp_dir(), name)
- writer = tf_record.TFRecordWriter(fn, options=None)
- for r in records:
- writer.write(r)
- writer.close()
- del writer
- return fn
+ # Note: Gzips written by TFRecordWriter add 'tfrecord_0' so
+ # compressed_files can't be compared with gzip_files
- def _ZlibCompressFile(self, infile, name="tfrecord.z"):
- # zlib compress the file and write compressed contents to file.
- with open(infile, "rb") as f:
- cdata = zlib.compress(f.read())
+ # Decompress the compressed files and verify they match the originals.
+ uncompressed_files = [
+ self._GzipDecompressFile(fn, "tfrecord_%s.gz" % i)
+ for i, fn in enumerate(compressed_files)
+ ]
+ self._AssertFilesEqual(uncompressed_files, files, True)
- zfn = os.path.join(self.get_temp_dir(), name)
- with open(zfn, "wb") as f:
- f.write(cdata)
- return zfn
+
+class TFRecordWriterZlibTest(TFCompressionTestCase):
def testOneEpoch(self):
- files = self._CreateFiles()
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
+ files = self._CreateFiles(options)
with self.test_session() as sess:
- options = tf_record.TFRecordOptions(
- compression_type=TFRecordCompressionType.ZLIB)
reader = io_ops.TFRecordReader(name="test_reader", options=options)
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
key, value = reader.read(queue)
@@ -788,8 +809,7 @@ class TFRecordWriterZlibTest(test.TestCase):
h.write(output)
with self.test_session() as sess:
- options = tf_record.TFRecordOptions(
- compression_type=TFRecordCompressionType.ZLIB)
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
reader = io_ops.TFRecordReader(name="test_reader", options=options)
queue = data_flow_ops.FIFOQueue(1, [dtypes.string], shapes=())
key, value = reader.read(queue)
@@ -808,9 +828,7 @@ class TFRecordWriterZlibTest(test.TestCase):
# read the compressed contents and verify.
actual = []
for r in tf_record.tf_record_iterator(
- zfn,
- options=tf_record.TFRecordOptions(
- tf_record.TFRecordCompressionType.ZLIB)):
+ zfn, options=tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)):
actual.append(r)
self.assertEqual(actual, original)
@@ -822,12 +840,9 @@ class TFRecordWriterZlibTest(test.TestCase):
fn = self._WriteRecordsToFile(original, "zlib_read_write_large.tfrecord")
zfn = self._ZlibCompressFile(fn, "zlib_read_write_large.tfrecord.z")
- # read the compressed contents and verify.
actual = []
for r in tf_record.tf_record_iterator(
- zfn,
- options=tf_record.TFRecordOptions(
- tf_record.TFRecordCompressionType.ZLIB)):
+ zfn, options=tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)):
actual.append(r)
self.assertEqual(actual, original)
@@ -835,13 +850,7 @@ class TFRecordWriterZlibTest(test.TestCase):
"""Verify that files produced are gzip compatible."""
original = [b"foo", b"bar"]
fn = self._WriteRecordsToFile(original, "gzip_read_write.tfrecord")
-
- # gzip compress the file and write compressed contents to file.
- with open(fn, "rb") as f:
- cdata = f.read()
- gzfn = os.path.join(self.get_temp_dir(), "tf_record.gz")
- with gzip.GzipFile(gzfn, "wb") as f:
- f.write(cdata)
+ gzfn = self._GzipCompressFile(fn, "tfrecord.gz")
actual = []
for r in tf_record.tf_record_iterator(
@@ -850,89 +859,54 @@ class TFRecordWriterZlibTest(test.TestCase):
self.assertEqual(actual, original)
-class TFRecordIteratorTest(test.TestCase):
+class TFRecordIteratorTest(TFCompressionTestCase):
def setUp(self):
super(TFRecordIteratorTest, self).setUp()
self._num_records = 7
- def _Record(self, r):
- return compat.as_bytes("Record %d" % r)
-
- def _WriteCompressedRecordsToFile(
- self,
- records,
- name="tfrecord.z",
- compression_type=tf_record.TFRecordCompressionType.ZLIB):
- fn = os.path.join(self.get_temp_dir(), name)
- options = tf_record.TFRecordOptions(compression_type=compression_type)
- writer = tf_record.TFRecordWriter(fn, options=options)
- for r in records:
- writer.write(r)
- writer.close()
- del writer
- return fn
-
- def _ZlibDecompressFile(self, infile, name="tfrecord", wbits=zlib.MAX_WBITS):
- with open(infile, "rb") as f:
- cdata = zlib.decompress(f.read(), wbits)
- zfn = os.path.join(self.get_temp_dir(), name)
- with open(zfn, "wb") as f:
- f.write(cdata)
- return zfn
-
def testIterator(self):
- fn = self._WriteCompressedRecordsToFile(
- [self._Record(i) for i in range(self._num_records)],
- "compressed_records")
- options = tf_record.TFRecordOptions(
- compression_type=TFRecordCompressionType.ZLIB)
+ records = [self._Record(0, i) for i in range(self._num_records)]
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
+ fn = self._WriteRecordsToFile(records, "compressed_records", options)
+
reader = tf_record.tf_record_iterator(fn, options)
- for i in range(self._num_records):
+ for expected in records:
record = next(reader)
- self.assertAllEqual(self._Record(i), record)
+ self.assertAllEqual(expected, record)
with self.assertRaises(StopIteration):
record = next(reader)
def testWriteZlibRead(self):
"""Verify compression with TFRecordWriter is zlib library compatible."""
original = [b"foo", b"bar"]
- fn = self._WriteCompressedRecordsToFile(original,
- "write_zlib_read.tfrecord.z")
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
+ fn = self._WriteRecordsToFile(original, "write_zlib_read.tfrecord.z",
+ options)
+
zfn = self._ZlibDecompressFile(fn, "write_zlib_read.tfrecord")
- actual = []
- for r in tf_record.tf_record_iterator(zfn):
- actual.append(r)
+ actual = list(tf_record.tf_record_iterator(zfn))
self.assertEqual(actual, original)
def testWriteZlibReadLarge(self):
"""Verify compression for large records is zlib library compatible."""
# Make it large (about 5MB)
original = [_TEXT * 10240]
- fn = self._WriteCompressedRecordsToFile(original,
- "write_zlib_read_large.tfrecord.z")
- zfn = self._ZlibDecompressFile(fn, "write_zlib_read_large.tf_record")
- actual = []
- for r in tf_record.tf_record_iterator(zfn):
- actual.append(r)
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
+ fn = self._WriteRecordsToFile(original, "write_zlib_read_large.tfrecord.z",
+ options)
+ zfn = self._ZlibDecompressFile(fn, "write_zlib_read_large.tfrecord")
+ actual = list(tf_record.tf_record_iterator(zfn))
self.assertEqual(actual, original)
def testWriteGzipRead(self):
original = [b"foo", b"bar"]
- fn = self._WriteCompressedRecordsToFile(
- original,
- "write_gzip_read.tfrecord.gz",
- compression_type=TFRecordCompressionType.GZIP)
-
- with gzip.GzipFile(fn, "rb") as f:
- cdata = f.read()
- zfn = os.path.join(self.get_temp_dir(), "tf_record")
- with open(zfn, "wb") as f:
- f.write(cdata)
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)
+ fn = self._WriteRecordsToFile(original, "write_gzip_read.tfrecord.gz",
+ options)
- actual = []
- for r in tf_record.tf_record_iterator(zfn):
- actual.append(r)
+ gzfn = self._GzipDecompressFile(fn, "write_gzip_read.tfrecord")
+ actual = list(tf_record.tf_record_iterator(gzfn))
self.assertEqual(actual, original)
def testBadFile(self):