diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-05 11:58:16 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-05 12:02:31 -0700 |
commit | 62a70dd873bc8488b10df5ad55254119173a5d0c (patch) | |
tree | 6841cb19c6e0e894a847b72bb612f8ea4b148a41 | |
parent | b1fd2ef4d02719cd929fa574796b2c080a21a9ee (diff) |
Extend and refactor reader_ops_test
PiperOrigin-RevId: 199335030
-rw-r--r-- | tensorflow/python/kernel_tests/reader_ops_test.py | 352 |
1 files changed, 163 insertions, 189 deletions
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py index 82a27eebee..7be473a5e7 100644 --- a/tensorflow/python/kernel_tests/reader_ops_test.py +++ b/tensorflow/python/kernel_tests/reader_ops_test.py @@ -77,6 +77,69 @@ _TEXT = b"""Gaily bedight, """ +class TFCompressionTestCase(test.TestCase): + + def setUp(self): + super(TFCompressionTestCase, self).setUp() + self._num_files = 2 + self._num_records = 7 + + def _Record(self, f, r): + return compat.as_bytes("Record %d of file %d" % (r, f)) + + def _CreateFiles(self, options=None, prefix=""): + filenames = [] + for i in range(self._num_files): + name = prefix + "tfrecord.%d.txt" % i + records = [self._Record(i, j) for j in range(self._num_records)] + fn = self._WriteRecordsToFile(records, name, options) + filenames.append(fn) + return filenames + + def _WriteRecordsToFile(self, records, name="tfrecord", options=None): + fn = os.path.join(self.get_temp_dir(), name) + with tf_record.TFRecordWriter(fn, options=options) as writer: + for r in records: + writer.write(r) + return fn + + def _ZlibCompressFile(self, infile, name="tfrecord.z"): + # zlib compress the file and write compressed contents to file. + with open(infile, "rb") as f: + cdata = zlib.compress(f.read()) + + zfn = os.path.join(self.get_temp_dir(), name) + with open(zfn, "wb") as f: + f.write(cdata) + return zfn + + def _GzipCompressFile(self, infile, name="tfrecord.gz"): + # gzip compress the file and write compressed contents to file. + with open(infile, "rb") as f: + cdata = f.read() + + gzfn = os.path.join(self.get_temp_dir(), name) + with gzip.GzipFile(gzfn, "wb") as f: + f.write(cdata) + return gzfn + + def _ZlibDecompressFile(self, infile, name="tfrecord"): + with open(infile, "rb") as f: + cdata = zlib.decompress(f.read()) + fn = os.path.join(self.get_temp_dir(), name) + with open(fn, "wb") as f: + f.write(cdata) + return fn + + def _GzipDecompressFile(self, infile, name="tfrecord"): + with gzip.GzipFile(infile, "rb") as f: + cdata = f.read() + fn = os.path.join(self.get_temp_dir(), name) + with open(fn, "wb") as f: + f.write(cdata) + return fn + + class IdentityReaderTest(test.TestCase): def _ExpectRead(self, sess, key, value, expected): @@ -348,7 +411,7 @@ class TextLineReaderTest(test.TestCase): k, v = sess.run([key, value]) -class FixedLengthRecordReaderTest(test.TestCase): +class FixedLengthRecordReaderTest(TFCompressionTestCase): def setUp(self): super(FixedLengthRecordReaderTest, self).setUp() @@ -407,40 +470,18 @@ class FixedLengthRecordReaderTest(test.TestCase): # gap_bytes=hop_bytes-record_bytes def _CreateGzipFiles(self, num_records, gap_bytes): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i) - filenames.append(fn) - with gzip.GzipFile(fn, "wb") as f: - f.write(b"H" * self._header_bytes) - if num_records > 0: - f.write(self._Record(i, 0)) - for j in range(1, num_records): - if gap_bytes > 0: - f.write(b"G" * gap_bytes) - f.write(self._Record(i, j)) - f.write(b"F" * self._footer_bytes) + filenames = self._CreateFiles(num_records, gap_bytes) + for fn in filenames: + # compress inplace. + self._GzipCompressFile(fn, fn) return filenames # gap_bytes=hop_bytes-record_bytes def _CreateZlibFiles(self, num_records, gap_bytes): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i) - filenames.append(fn) - with open(fn + ".tmp", "wb") as f: - f.write(b"H" * self._header_bytes) - if num_records > 0: - f.write(self._Record(i, 0)) - for j in range(1, num_records): - if gap_bytes > 0: - f.write(b"G" * gap_bytes) - f.write(self._Record(i, j)) - f.write(b"F" * self._footer_bytes) - with open(fn + ".tmp", "rb") as f: - cdata = zlib.compress(f.read()) - with open(fn, "wb") as zf: - zf.write(cdata) + filenames = self._CreateFiles(num_records, gap_bytes) + for fn in filenames: + # compress inplace. + self._ZlibCompressFile(fn, fn) return filenames def _CreateGzipOverlappedRecordFiles(self, num_overlapped_records): @@ -477,10 +518,7 @@ class FixedLengthRecordReaderTest(test.TestCase): ]) f.write(compat.as_bytes(all_records_str)) f.write(b"F" * self._footer_bytes) - with open(fn + ".tmp", "rb") as f: - cdata = zlib.compress(f.read()) - with open(fn, "wb") as zf: - zf.write(cdata) + self._ZlibCompressFile(fn + ".tmp", fn) return filenames # gap_bytes=hop_bytes-record_bytes @@ -529,7 +567,6 @@ class FixedLengthRecordReaderTest(test.TestCase): for i in range(self._num_files): for j in range(num_overlapped_records): k, v = sess.run([key, value]) - print(v) self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k)) self.assertAllEqual(self._OverlappedRecord(i, j), v) @@ -579,25 +616,10 @@ class FixedLengthRecordReaderTest(test.TestCase): files, num_overlapped_records, encoding="ZLIB") -class TFRecordReaderTest(test.TestCase): +class TFRecordReaderTest(TFCompressionTestCase): def setUp(self): super(TFRecordReaderTest, self).setUp() - self._num_files = 2 - self._num_records = 7 - - def _Record(self, f, r): - return compat.as_bytes("Record %d of file %d" % (r, f)) - - def _CreateFiles(self): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) - filenames.append(fn) - writer = tf_record.TFRecordWriter(fn) - for j in range(self._num_records): - writer.write(self._Record(i, j)) - return filenames def testOneEpoch(self): files = self._CreateFiles() @@ -647,107 +669,106 @@ class TFRecordReaderTest(test.TestCase): self.assertEqual(self._num_files * self._num_records, num_v) def testReadZlibFiles(self): - files = self._CreateFiles() - zlib_files = [] - for i, fn in enumerate(files): - with open(fn, "rb") as f: - cdata = zlib.compress(f.read()) - - zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i) - with open(zfn, "wb") as f: - f.write(cdata) - zlib_files.append(zfn) + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + files = self._CreateFiles(options) with self.test_session() as sess: - options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) reader = io_ops.TFRecordReader(name="test_reader", options=options) queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=()) key, value = reader.read(queue) - queue.enqueue_many([zlib_files]).run() + queue.enqueue_many([files]).run() queue.close().run() for i in range(self._num_files): for j in range(self._num_records): k, v = sess.run([key, value]) - self.assertTrue(compat.as_text(k).startswith("%s:" % zlib_files[i])) + self.assertTrue(compat.as_text(k).startswith("%s:" % files[i])) self.assertAllEqual(self._Record(i, j), v) def testReadGzipFiles(self): - files = self._CreateFiles() - gzip_files = [] - for i, fn in enumerate(files): - with open(fn, "rb") as f: - cdata = f.read() - - zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i) - with gzip.GzipFile(zfn, "wb") as f: - f.write(cdata) - gzip_files.append(zfn) + options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP) + files = self._CreateFiles(options) with self.test_session() as sess: - options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP) reader = io_ops.TFRecordReader(name="test_reader", options=options) queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=()) key, value = reader.read(queue) - queue.enqueue_many([gzip_files]).run() + queue.enqueue_many([files]).run() queue.close().run() for i in range(self._num_files): for j in range(self._num_records): k, v = sess.run([key, value]) - self.assertTrue(compat.as_text(k).startswith("%s:" % gzip_files[i])) + self.assertTrue(compat.as_text(k).startswith("%s:" % files[i])) self.assertAllEqual(self._Record(i, j), v) -class TFRecordWriterZlibTest(test.TestCase): +class TFRecordWriterTest(TFCompressionTestCase): def setUp(self): - super(TFRecordWriterZlibTest, self).setUp() - self._num_files = 2 - self._num_records = 7 + super(TFRecordWriterTest, self).setUp() + + def _AssertFilesEqual(self, a, b, equal): + for an, bn in zip(a, b): + with open(an, "rb") as af, open(bn, "rb") as bf: + if equal: + self.assertEqual(af.read(), bf.read()) + else: + self.assertNotEqual(af.read(), bf.read()) + + def testWriteReadZLibFiles(self): + # Write uncompressed then compress manually. + options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE) + files = self._CreateFiles(options, prefix="uncompressed") + zlib_files = [ + self._ZlibCompressFile(fn, "tfrecord_%s.z" % i) + for i, fn in enumerate(files) + ] + self._AssertFilesEqual(files, zlib_files, False) - def _Record(self, f, r): - return compat.as_bytes("Record %d of file %d" % (r, f)) + # Now write compressd and verify same. + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + compressed_files = self._CreateFiles(options, prefix="compressed") + self._AssertFilesEqual(compressed_files, zlib_files, True) - def _CreateFiles(self): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) - filenames.append(fn) - options = tf_record.TFRecordOptions( - compression_type=TFRecordCompressionType.ZLIB) - writer = tf_record.TFRecordWriter(fn, options=options) - for j in range(self._num_records): - writer.write(self._Record(i, j)) - writer.close() - del writer + # Decompress compress and verify same. + uncompressed_files = [ + self._ZlibDecompressFile(fn, "tfrecord_%s.z" % i) + for i, fn in enumerate(compressed_files) + ] + self._AssertFilesEqual(uncompressed_files, files, True) + + def testWriteReadGzipFiles(self): + # Write uncompressed then compress manually. + options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE) + files = self._CreateFiles(options, prefix="uncompressed") + gzip_files = [ + self._GzipCompressFile(fn, "tfrecord_%s.gz" % i) + for i, fn in enumerate(files) + ] + self._AssertFilesEqual(files, gzip_files, False) - return filenames + # Now write compressd and verify same. + options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP) + compressed_files = self._CreateFiles(options, prefix="compressed") - def _WriteRecordsToFile(self, records, name="tf_record"): - fn = os.path.join(self.get_temp_dir(), name) - writer = tf_record.TFRecordWriter(fn, options=None) - for r in records: - writer.write(r) - writer.close() - del writer - return fn + # Note: Gzips written by TFRecordWriter add 'tfrecord_0' so + # compressed_files can't be compared with gzip_files - def _ZlibCompressFile(self, infile, name="tfrecord.z"): - # zlib compress the file and write compressed contents to file. - with open(infile, "rb") as f: - cdata = zlib.compress(f.read()) + # Decompress compress and verify same. + uncompressed_files = [ + self._GzipDecompressFile(fn, "tfrecord_%s.gz" % i) + for i, fn in enumerate(compressed_files) + ] + self._AssertFilesEqual(uncompressed_files, files, True) - zfn = os.path.join(self.get_temp_dir(), name) - with open(zfn, "wb") as f: - f.write(cdata) - return zfn + +class TFRecordWriterZlibTest(TFCompressionTestCase): def testOneEpoch(self): - files = self._CreateFiles() + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + files = self._CreateFiles(options) with self.test_session() as sess: - options = tf_record.TFRecordOptions( - compression_type=TFRecordCompressionType.ZLIB) reader = io_ops.TFRecordReader(name="test_reader", options=options) queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=()) key, value = reader.read(queue) @@ -788,8 +809,7 @@ class TFRecordWriterZlibTest(test.TestCase): h.write(output) with self.test_session() as sess: - options = tf_record.TFRecordOptions( - compression_type=TFRecordCompressionType.ZLIB) + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) reader = io_ops.TFRecordReader(name="test_reader", options=options) queue = data_flow_ops.FIFOQueue(1, [dtypes.string], shapes=()) key, value = reader.read(queue) @@ -808,9 +828,7 @@ class TFRecordWriterZlibTest(test.TestCase): # read the compressed contents and verify. actual = [] for r in tf_record.tf_record_iterator( - zfn, - options=tf_record.TFRecordOptions( - tf_record.TFRecordCompressionType.ZLIB)): + zfn, options=tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)): actual.append(r) self.assertEqual(actual, original) @@ -822,12 +840,9 @@ class TFRecordWriterZlibTest(test.TestCase): fn = self._WriteRecordsToFile(original, "zlib_read_write_large.tfrecord") zfn = self._ZlibCompressFile(fn, "zlib_read_write_large.tfrecord.z") - # read the compressed contents and verify. actual = [] for r in tf_record.tf_record_iterator( - zfn, - options=tf_record.TFRecordOptions( - tf_record.TFRecordCompressionType.ZLIB)): + zfn, options=tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)): actual.append(r) self.assertEqual(actual, original) @@ -835,13 +850,7 @@ class TFRecordWriterZlibTest(test.TestCase): """Verify that files produced are gzip compatible.""" original = [b"foo", b"bar"] fn = self._WriteRecordsToFile(original, "gzip_read_write.tfrecord") - - # gzip compress the file and write compressed contents to file. - with open(fn, "rb") as f: - cdata = f.read() - gzfn = os.path.join(self.get_temp_dir(), "tf_record.gz") - with gzip.GzipFile(gzfn, "wb") as f: - f.write(cdata) + gzfn = self._GzipCompressFile(fn, "tfrecord.gz") actual = [] for r in tf_record.tf_record_iterator( @@ -850,89 +859,54 @@ class TFRecordWriterZlibTest(test.TestCase): self.assertEqual(actual, original) -class TFRecordIteratorTest(test.TestCase): +class TFRecordIteratorTest(TFCompressionTestCase): def setUp(self): super(TFRecordIteratorTest, self).setUp() self._num_records = 7 - def _Record(self, r): - return compat.as_bytes("Record %d" % r) - - def _WriteCompressedRecordsToFile( - self, - records, - name="tfrecord.z", - compression_type=tf_record.TFRecordCompressionType.ZLIB): - fn = os.path.join(self.get_temp_dir(), name) - options = tf_record.TFRecordOptions(compression_type=compression_type) - writer = tf_record.TFRecordWriter(fn, options=options) - for r in records: - writer.write(r) - writer.close() - del writer - return fn - - def _ZlibDecompressFile(self, infile, name="tfrecord", wbits=zlib.MAX_WBITS): - with open(infile, "rb") as f: - cdata = zlib.decompress(f.read(), wbits) - zfn = os.path.join(self.get_temp_dir(), name) - with open(zfn, "wb") as f: - f.write(cdata) - return zfn - def testIterator(self): - fn = self._WriteCompressedRecordsToFile( - [self._Record(i) for i in range(self._num_records)], - "compressed_records") - options = tf_record.TFRecordOptions( - compression_type=TFRecordCompressionType.ZLIB) + records = [self._Record(0, i) for i in range(self._num_records)] + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + fn = self._WriteRecordsToFile(records, "compressed_records", options) + reader = tf_record.tf_record_iterator(fn, options) - for i in range(self._num_records): + for expected in records: record = next(reader) - self.assertAllEqual(self._Record(i), record) + self.assertAllEqual(expected, record) with self.assertRaises(StopIteration): record = next(reader) def testWriteZlibRead(self): """Verify compression with TFRecordWriter is zlib library compatible.""" original = [b"foo", b"bar"] - fn = self._WriteCompressedRecordsToFile(original, - "write_zlib_read.tfrecord.z") + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + fn = self._WriteRecordsToFile(original, "write_zlib_read.tfrecord.z", + options) + zfn = self._ZlibDecompressFile(fn, "write_zlib_read.tfrecord") - actual = [] - for r in tf_record.tf_record_iterator(zfn): - actual.append(r) + actual = list(tf_record.tf_record_iterator(zfn)) self.assertEqual(actual, original) def testWriteZlibReadLarge(self): """Verify compression for large records is zlib library compatible.""" # Make it large (about 5MB) original = [_TEXT * 10240] - fn = self._WriteCompressedRecordsToFile(original, - "write_zlib_read_large.tfrecord.z") - zfn = self._ZlibDecompressFile(fn, "write_zlib_read_large.tf_record") - actual = [] - for r in tf_record.tf_record_iterator(zfn): - actual.append(r) + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + fn = self._WriteRecordsToFile(original, "write_zlib_read_large.tfrecord.z", + options) + zfn = self._ZlibDecompressFile(fn, "write_zlib_read_large.tfrecord") + actual = list(tf_record.tf_record_iterator(zfn)) self.assertEqual(actual, original) def testWriteGzipRead(self): original = [b"foo", b"bar"] - fn = self._WriteCompressedRecordsToFile( - original, - "write_gzip_read.tfrecord.gz", - compression_type=TFRecordCompressionType.GZIP) - - with gzip.GzipFile(fn, "rb") as f: - cdata = f.read() - zfn = os.path.join(self.get_temp_dir(), "tf_record") - with open(zfn, "wb") as f: - f.write(cdata) + options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP) + fn = self._WriteRecordsToFile(original, "write_gzip_read.tfrecord.gz", + options) - actual = [] - for r in tf_record.tf_record_iterator(zfn): - actual.append(r) + gzfn = self._GzipDecompressFile(fn, "write_gzip_read.tfrecord") + actual = list(tf_record.tf_record_iterator(gzfn)) self.assertEqual(actual, original) def testBadFile(self): |