diff options
author | 2016-06-16 14:40:17 -0800 | |
---|---|---|
committer | 2016-06-16 15:48:05 -0700 | |
commit | 90bad51f97cc32dc9bc8e2ffb3e81b65cec37dfb (patch) | |
tree | 22621e351e2a91323e152f17039162872d07798e | |
parent | 598867b6a207be402cf3555697a212825b81a882 (diff) |
Add support to TFRecordReader and TFRecordWriter for reading and writing zlib-compressed files.
Change: 125110928
23 files changed, 422 insertions, 59 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index f1408d44d6..77db4004c7 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -633,7 +633,6 @@ filegroup( "**/*testlib*", "**/*main.cc", "graph/dot.*", - "lib/io/zlib*", "lib/jpeg/**/*", "lib/png/**/*", "util/checkpoint_reader.*", @@ -688,6 +687,7 @@ cc_library( name = "android_tensorflow_lib", srcs = if_android([":android_op_registrations_and_gradients"]), copts = tf_copts(), + linkopts = ["-lz"], tags = [ "manual", "notap", @@ -1270,6 +1270,7 @@ tf_cc_tests( "lib/io/inputbuffer_test.cc", "lib/io/match_test.cc", "lib/io/path_test.cc", + "lib/io/record_reader_writer_test.cc", "lib/io/recordio_test.cc", "lib/io/table_test.cc", "lib/io/zlib_buffers_test.cc", diff --git a/tensorflow/core/kernels/tf_record_reader_op.cc b/tensorflow/core/kernels/tf_record_reader_op.cc index 9604d44567..3cbb334141 100644 --- a/tensorflow/core/kernels/tf_record_reader_op.cc +++ b/tensorflow/core/kernels/tf_record_reader_op.cc @@ -27,17 +27,25 @@ namespace tensorflow { class TFRecordReader : public ReaderBase { public: - TFRecordReader(const string& node_name, Env* env) + TFRecordReader(const string& node_name, const string& compression_type, + Env* env) : ReaderBase(strings::StrCat("TFRecordReader '", node_name, "'")), env_(env), - offset_(0) {} + offset_(0), + compression_type_(compression_type) {} Status OnWorkStartedLocked() override { offset_ = 0; RandomAccessFile* file = nullptr; TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file)); file_.reset(file); - reader_.reset(new io::RecordReader(file)); + + io::RecordReaderOptions options; + if (compression_type_ == "ZLIB") { + options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION; + } + + reader_.reset(new io::RecordReader(file, options)); return Status::OK(); } @@ -74,6 +82,7 @@ class TFRecordReader : public ReaderBase { uint64 offset_; std::unique_ptr<RandomAccessFile> file_; std::unique_ptr<io::RecordReader> reader_; + string 
compression_type_ = ""; }; class TFRecordReaderOp : public ReaderOpKernel { @@ -81,7 +90,13 @@ class TFRecordReaderOp : public ReaderOpKernel { explicit TFRecordReaderOp(OpKernelConstruction* context) : ReaderOpKernel(context) { Env* env = context->env(); - SetReaderFactory([this, env]() { return new TFRecordReader(name(), env); }); + + string compression_type; + context->GetAttr("compression_type", &compression_type); + + SetReaderFactory([this, compression_type, env]() { + return new TFRecordReader(name(), compression_type, env); + }); } }; diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc index 76011430de..eb194a14d4 100644 --- a/tensorflow/core/lib/io/record_reader.cc +++ b/tensorflow/core/lib/io/record_reader.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/lib/io/record_reader.h" #include <limits.h> + #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/crc32c.h" @@ -24,38 +25,76 @@ limitations under the License. namespace tensorflow { namespace io { -RecordReader::RecordReader(RandomAccessFile* file) : src_(file) {} +RecordReader::RecordReader(RandomAccessFile* file, + const RecordReaderOptions& options) + : src_(file), options_(options) { + if (options.compression_type == RecordReaderOptions::ZLIB_COMPRESSION) { + zlib_input_buffer_.reset(new ZlibInputBuffer( + src_, options.zlib_options.input_buffer_size, + options.zlib_options.output_buffer_size, options.zlib_options)); + } else if (options.compression_type == RecordReaderOptions::NONE) { + // Nothing to do. + } else { + LOG(FATAL) << "Unspecified compression type :" << options.compression_type; + } +} RecordReader::~RecordReader() {} // Read n+4 bytes from file, verify that checksum of first n bytes is // stored in the last 4 bytes and store the first n bytes in *result. // May use *storage as backing store. 
-static Status ReadChecksummed(RandomAccessFile* file, uint64 offset, size_t n, - StringPiece* result, string* storage) { +Status RecordReader::ReadChecksummed(uint64 offset, size_t n, + StringPiece* result, string* storage) { if (n >= SIZE_MAX - sizeof(uint32)) { return errors::DataLoss("record size too large"); } const size_t expected = n + sizeof(uint32); storage->resize(expected); - StringPiece data; - Status s = file->Read(offset, expected, &data, &(*storage)[0]); - if (!s.ok()) { - return s; - } - if (data.size() != expected) { - if (data.size() == 0) { - return errors::OutOfRange("eof"); - } else { - return errors::DataLoss("truncated record at ", offset); + + if (zlib_input_buffer_) { + // If we have a zlib compressed buffer, we assume that the + // file is being read sequentially, and we use the underlying + // implementation to read the data. + // + // No checks are done to validate that the file is being read + // sequentially. At some point the zlib input buffer may support + // seeking, possibly inefficiently. + TF_RETURN_IF_ERROR(zlib_input_buffer_->ReadNBytes(expected, storage)); + + if (storage->size() != expected) { + if (storage->size() == 0) { + return errors::OutOfRange("eof"); + } else { + return errors::DataLoss("truncated record at ", offset); + } } + + uint32 masked_crc = core::DecodeFixed32(storage->data() + n); + if (crc32c::Unmask(masked_crc) != crc32c::Value(storage->data(), n)) { + return errors::DataLoss("corrupted record at ", offset); + } + *result = StringPiece(storage->data(), n); + } else { + // This version supports reading from arbitrary offsets + // since we are accessing the random access file directly. 
+ StringPiece data; + TF_RETURN_IF_ERROR(src_->Read(offset, expected, &data, &(*storage)[0])); + if (data.size() != expected) { + if (data.size() == 0) { + return errors::OutOfRange("eof"); + } else { + return errors::DataLoss("truncated record at ", offset); + } + } + uint32 masked_crc = core::DecodeFixed32(data.data() + n); + if (crc32c::Unmask(masked_crc) != crc32c::Value(data.data(), n)) { + return errors::DataLoss("corrupted record at ", offset); + } + *result = StringPiece(data.data(), n); } - uint32 masked_crc = core::DecodeFixed32(data.data() + n); - if (crc32c::Unmask(masked_crc) != crc32c::Value(data.data(), n)) { - return errors::DataLoss("corrupted record at ", offset); - } - *result = StringPiece(data.data(), n); + return Status::OK(); } @@ -63,9 +102,9 @@ Status RecordReader::ReadRecord(uint64* offset, string* record) { static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32); static const size_t kFooterSize = sizeof(uint32); - // Read length + // Read header data. StringPiece lbuf; - Status s = ReadChecksummed(src_, *offset, sizeof(uint64), &lbuf, record); + Status s = ReadChecksummed(*offset, sizeof(uint64), &lbuf, record); if (!s.ok()) { return s; } @@ -73,19 +112,21 @@ Status RecordReader::ReadRecord(uint64* offset, string* record) { // Read data StringPiece data; - s = ReadChecksummed(src_, *offset + kHeaderSize, length, &data, record); + s = ReadChecksummed(*offset + kHeaderSize, length, &data, record); if (!s.ok()) { if (errors::IsOutOfRange(s)) { s = errors::DataLoss("truncated record at ", *offset); } return s; } + if (record->data() != data.data()) { // RandomAccessFile placed the data in some other location. 
memmove(&(*record)[0], data.data(), data.size()); } record->resize(data.size()); + *offset += kHeaderSize + length + kFooterSize; return Status::OK(); } diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h index f96d24aa31..b4c56451be 100644 --- a/tensorflow/core/lib/io/record_reader.h +++ b/tensorflow/core/lib/io/record_reader.h @@ -18,6 +18,9 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/io/inputbuffer.h" +#include "tensorflow/core/lib/io/zlib_compression_options.h" +#include "tensorflow/core/lib/io/zlib_inputbuffer.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -27,13 +30,23 @@ class RandomAccessFile; namespace io { +class RecordReaderOptions { + public: + enum CompressionType { NONE = 0, ZLIB_COMPRESSION = 1 }; + CompressionType compression_type = NONE; + + // Options specific to zlib compression. + ZlibCompressionOptions zlib_options; +}; + class RecordReader { public: // Create a reader that will return log records from "*file". // "*file" must remain live while this Reader is in use. - explicit RecordReader(RandomAccessFile* file); + RecordReader(RandomAccessFile* file, + const RecordReaderOptions& options = RecordReaderOptions()); - ~RecordReader(); + virtual ~RecordReader(); // Read the record at "*offset" into *record and update *offset to // point to the offset of the next record. 
Returns OK on success, @@ -41,7 +54,12 @@ class RecordReader { Status ReadRecord(uint64* offset, string* record); private: + Status ReadChecksummed(uint64 offset, size_t n, StringPiece* result, + string* storage); + RandomAccessFile* src_; + RecordReaderOptions options_; + std::unique_ptr<ZlibInputBuffer> zlib_input_buffer_; TF_DISALLOW_COPY_AND_ASSIGN(RecordReader); }; diff --git a/tensorflow/core/lib/io/record_reader_writer_test.cc b/tensorflow/core/lib/io/record_reader_writer_test.cc new file mode 100644 index 0000000000..ee8a7fa91a --- /dev/null +++ b/tensorflow/core/lib/io/record_reader_writer_test.cc @@ -0,0 +1,73 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/lib/io/record_reader.h" +#include "tensorflow/core/lib/io/record_writer.h" + +#include <vector> +#include "tensorflow/core/platform/env.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +static std::vector<int> BufferSizes() { + return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 65536}; +} + +TEST(RecordReaderWriterTest, TestBasics) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/record_reader_writer_test"; + + for (auto buf_size : BufferSizes()) { + WritableFile* file; + TF_CHECK_OK(env->NewWritableFile(fname, &file)); + + { + io::RecordWriterOptions options; + options.zlib_options.output_buffer_size = buf_size; + io::RecordWriter writer(file, options); + writer.WriteRecord("abc"); + writer.WriteRecord("defg"); + TF_CHECK_OK(writer.Flush()); + } + delete file; + + RandomAccessFile* read_file; + { + // Read it back with the RecordReader. 
+ TF_CHECK_OK(env->NewRandomAccessFile(fname, &read_file)); + io::RecordReaderOptions options; + options.zlib_options.input_buffer_size = buf_size; + io::RecordReader reader(read_file, options); + uint64 offset = 0; + string record; + TF_CHECK_OK(reader.ReadRecord(&offset, &record)); + EXPECT_EQ("abc", record); + TF_CHECK_OK(reader.ReadRecord(&offset, &record)); + EXPECT_EQ("defg", record); + } + + delete read_file; + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc index e42f1256a4..7993f6ca20 100644 --- a/tensorflow/core/lib/io/record_writer.cc +++ b/tensorflow/core/lib/io/record_writer.cc @@ -22,9 +22,28 @@ limitations under the License. namespace tensorflow { namespace io { -RecordWriter::RecordWriter(WritableFile* dest) : dest_(dest) {} +RecordWriter::RecordWriter(WritableFile* dest, + const RecordWriterOptions& options) + : dest_(dest), options_(options) { + if (options.compression_type == RecordWriterOptions::ZLIB_COMPRESSION) { + zlib_output_buffer_.reset(new ZlibOutputBuffer( + dest_, options.zlib_options.input_buffer_size, + options.zlib_options.output_buffer_size, options.zlib_options)); + } else if (options.compression_type == RecordWriterOptions::NONE) { + // Nothing to do + } else { + LOG(FATAL) << "Unspecified compression type :" << options.compression_type; + } +} -RecordWriter::~RecordWriter() {} +RecordWriter::~RecordWriter() { + if (zlib_output_buffer_) { + Status s = zlib_output_buffer_->Close(); + if (!s.ok()) { + LOG(ERROR) << "Could not finish writing file: " << s; + } + } +} static uint32 MaskedCrc(const char* data, size_t n) { return crc32c::Mask(crc32c::Value(data, n)); @@ -40,17 +59,19 @@ Status RecordWriter::WriteRecord(StringPiece data) { core::EncodeFixed64(header + 0, data.size()); core::EncodeFixed32(header + sizeof(uint64), MaskedCrc(header, sizeof(uint64))); - Status s = dest_->Append(StringPiece(header, sizeof(header))); - if (!s.ok()) { - return s; - 
} - s = dest_->Append(data); - if (!s.ok()) { - return s; - } char footer[sizeof(uint32)]; core::EncodeFixed32(footer, MaskedCrc(data.data(), data.size())); - return dest_->Append(StringPiece(footer, sizeof(footer))); + + if (zlib_output_buffer_) { + TF_RETURN_IF_ERROR( + zlib_output_buffer_->Write(StringPiece(header, sizeof(header)))); + TF_RETURN_IF_ERROR(zlib_output_buffer_->Write(data)); + return zlib_output_buffer_->Write(StringPiece(footer, sizeof(footer))); + } else { + TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); + TF_RETURN_IF_ERROR(dest_->Append(data)); + return dest_->Append(StringPiece(footer, sizeof(footer))); + } } } // namespace io diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h index 93b2a08714..2344df3b25 100644 --- a/tensorflow/core/lib/io/record_writer.h +++ b/tensorflow/core/lib/io/record_writer.h @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/io/zlib_compression_options.h" +#include "tensorflow/core/lib/io/zlib_outputbuffer.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -27,19 +29,42 @@ class WritableFile; namespace io { +class RecordWriterOptions { + public: + enum CompressionType { NONE = 0, ZLIB_COMPRESSION = 1 }; + CompressionType compression_type = NONE; + + // Options specific to zlib compression. + ZlibCompressionOptions zlib_options; +}; + class RecordWriter { public: // Create a writer that will append data to "*dest". // "*dest" must be initially empty. // "*dest" must remain live while this Writer is in use. 
- explicit RecordWriter(WritableFile* dest); + RecordWriter(WritableFile* dest, + const RecordWriterOptions& options = RecordWriterOptions()); ~RecordWriter(); Status WriteRecord(StringPiece slice); + // Flushes any buffered data held by underlying containers of the + // RecordWriter to the WritableFile. Does *not* flush the + // WritableFile. + Status Flush() { + if (zlib_output_buffer_) { + return zlib_output_buffer_->Flush(); + } + + return Status::OK(); + } + private: WritableFile* const dest_; + RecordWriterOptions options_; + std::unique_ptr<ZlibOutputBuffer> zlib_output_buffer_; TF_DISALLOW_COPY_AND_ASSIGN(RecordWriter); }; diff --git a/tensorflow/core/lib/io/zlib_compression_options.h b/tensorflow/core/lib/io/zlib_compression_options.h index 49afea4e36..95af0ab9c9 100644 --- a/tensorflow/core/lib/io/zlib_compression_options.h +++ b/tensorflow/core/lib/io/zlib_compression_options.h @@ -16,7 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_ #define TENSORFLOW_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_ +// TODO(srbs|vrv): Move to a platform/zlib.h file to centralize all +// platform-specific includes +#ifdef __ANDROID__ +#include "zlib.h" +#else #include <zlib.h> +#endif // __ANDROID__ namespace tensorflow { namespace io { @@ -28,6 +34,13 @@ class ZlibCompressionOptions { int8 flush_mode = Z_NO_FLUSH; + // Size of the buffer used for caching the data read from source file. + int64 input_buffer_size = 256 << 10; + + // Size of the sink buffer where the compressed/decompressed data produced by + // zlib is cached. + int64 output_buffer_size = 256 << 10; + // The window_bits parameter is the base two logarithm of the window size // (the size of the history buffer). Larger values of buffer size result in // better compression at the expense of memory usage. 
diff --git a/tensorflow/core/lib/io/zlib_inputbuffer.cc b/tensorflow/core/lib/io/zlib_inputbuffer.cc index 5f20fba0e6..b5224168f1 100644 --- a/tensorflow/core/lib/io/zlib_inputbuffer.cc +++ b/tensorflow/core/lib/io/zlib_inputbuffer.cc @@ -128,6 +128,7 @@ size_t ZlibInputBuffer::NumUnreadBytes() const { } Status ZlibInputBuffer::ReadNBytes(int64 bytes_to_read, string* result) { + result->clear(); // Read as many bytes as possible from cache. bytes_to_read -= ReadBytesFromCache(bytes_to_read, result); @@ -163,8 +164,8 @@ Status ZlibInputBuffer::ReadNBytes(int64 bytes_to_read, string* result) { Status ZlibInputBuffer::Inflate() { int error = inflate(z_stream_.get(), zlib_options_.flush_mode); if (error != Z_OK && error != Z_FINISH) { - string error_string = strings::StrCat("inflate() failed with error ", - std::to_string(error).c_str()); + string error_string = + strings::StrCat("inflate() failed with error ", error); if (z_stream_->msg != NULL) { strings::StrAppend(&error_string, ": ", z_stream_->msg); } diff --git a/tensorflow/core/lib/io/zlib_inputbuffer.h b/tensorflow/core/lib/io/zlib_inputbuffer.h index e21070134a..008d51876a 100644 --- a/tensorflow/core/lib/io/zlib_inputbuffer.h +++ b/tensorflow/core/lib/io/zlib_inputbuffer.h @@ -22,7 +22,14 @@ limitations under the License. 
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" + +// TODO(srbs|vrv): Move to a platform/zlib.h file to centralize all +// platform-specific includes +#ifdef __ANDROID__ +#include "zlib.h" +#else #include <zlib.h> +#endif // __ANDROID__ namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.cc b/tensorflow/core/lib/io/zlib_outputbuffer.cc index 6db9423c65..9493804bcb 100644 --- a/tensorflow/core/lib/io/zlib_outputbuffer.cc +++ b/tensorflow/core/lib/io/zlib_outputbuffer.cc @@ -200,8 +200,7 @@ Status ZlibOutputBuffer::Deflate(int flush) { (error == Z_STREAM_END && flush == Z_FINISH)) { return Status::OK(); } - string error_string = strings::StrCat("deflate() failed with error ", - std::to_string(error).c_str()); + string error_string = strings::StrCat("deflate() failed with error ", error); if (z_stream_->msg != NULL) { strings::StrAppend(&error_string, ": ", z_stream_->msg); } diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.h b/tensorflow/core/lib/io/zlib_outputbuffer.h index a5c6b16bd9..08455b63b5 100644 --- a/tensorflow/core/lib/io/zlib_outputbuffer.h +++ b/tensorflow/core/lib/io/zlib_outputbuffer.h @@ -22,7 +22,14 @@ limitations under the License. #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" + +// TODO(srbs|vrv): Move to a platform/zlib.h file to centralize all +// platform-specific includes. 
+#ifdef __ANDROID__ +#include "zlib.h" +#else #include <zlib.h> +#endif // __ANDROID__ namespace tensorflow { namespace io { diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc index 112a5278bf..10a0063299 100644 --- a/tensorflow/core/ops/io_ops.cc +++ b/tensorflow/core/ops/io_ops.cc @@ -217,6 +217,7 @@ REGISTER_OP("TFRecordReader") .Output("reader_handle: Ref(string)") .Attr("container: string = ''") .Attr("shared_name: string = ''") + .Attr("compression_type: string = ''") .SetIsStateful() .Doc(R"doc( A Reader that outputs the records from a TensorFlow Records file. diff --git a/tensorflow/core/util/events_writer.cc b/tensorflow/core/util/events_writer.cc index 21dca84f9c..c1f359c491 100644 --- a/tensorflow/core/util/events_writer.cc +++ b/tensorflow/core/util/events_writer.cc @@ -113,6 +113,13 @@ void EventsWriter::WriteEvent(const Event& event) { bool EventsWriter::Flush() { if (num_outstanding_events_ == 0) return true; CHECK(recordio_file_.get() != NULL) << "Unexpected NULL file"; + + if (!recordio_writer_->Flush().ok()) { + LOG(ERROR) << "Failed to flush " << num_outstanding_events_ << " events to " + << filename_; + return false; + } + // The FileHasDisappeared() condition is necessary because // recordio_writer_->Sync() can return true even if the underlying // file has been deleted. 
EventWriter.FileDeletionBeforeWriting diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py index a335344c83..8bcec44a61 100644 --- a/tensorflow/python/kernel_tests/reader_ops_test.py +++ b/tensorflow/python/kernel_tests/reader_ops_test.py @@ -410,6 +410,85 @@ class TFRecordReaderTest(tf.test.TestCase): self.assertEqual(self._num_files * self._num_records, num_v) +class TFRecordWriterZlibTest(tf.test.TestCase): + + def setUp(self): + super(TFRecordWriterZlibTest, self).setUp() + self._num_files = 2 + self._num_records = 7 + + def _Record(self, f, r): + return tf.compat.as_bytes("Record %d of file %d" % (r, f)) + + def _CreateFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) + filenames.append(fn) + options = tf.python_io.TFRecordOptions( + compression_type=tf.python_io.TFRecordCompressionType.ZLIB) + writer = tf.python_io.TFRecordWriter(fn, options=options) + for j in range(self._num_records): + writer.write(self._Record(i, j)) + writer.close() + del writer + + return filenames + + def testOneEpoch(self): + files = self._CreateFiles() + with self.test_session() as sess: + options = tf.python_io.TFRecordOptions( + compression_type=tf.python_io.TFRecordCompressionType.ZLIB) + reader = tf.TFRecordReader(name="test_reader", options=options) + queue = tf.FIFOQueue(99, [tf.string], shapes=()) + key, value = reader.read(queue) + + queue.enqueue_many([files]).run() + queue.close().run() + for i in range(self._num_files): + for j in range(self._num_records): + k, v = sess.run([key, value]) + self.assertTrue(tf.compat.as_text(k).startswith("%s:" % files[i])) + self.assertAllEqual(self._Record(i, j), v) + + with self.assertRaisesOpError("is closed and has insufficient elements " + "\\(requested 1, current size 0\\)"): + k, v = sess.run([key, value]) + + +class TFRecordIteratorTest(tf.test.TestCase): + + def setUp(self): + 
super(TFRecordIteratorTest, self).setUp() + self._num_records = 7 + + def _Record(self, r): + return tf.compat.as_bytes("Record %d" % r) + + def _CreateFile(self): + fn = os.path.join(self.get_temp_dir(), "tf_record.txt") + options = tf.python_io.TFRecordOptions( + compression_type=tf.python_io.TFRecordCompressionType.ZLIB) + writer = tf.python_io.TFRecordWriter(fn, options=options) + for i in range(self._num_records): + writer.write(self._Record(i)) + writer.close() + del writer + return fn + + def testIterator(self): + fn = self._CreateFile() + options = tf.python_io.TFRecordOptions( + compression_type=tf.python_io.TFRecordCompressionType.ZLIB) + reader = tf.python_io.tf_record_iterator(fn, options) + for i in range(self._num_records): + record = next(reader) + self.assertAllEqual(self._Record(i), record) + with self.assertRaises(StopIteration): + record = next(reader) + + class AsyncReaderTest(tf.test.TestCase): def testNoDeadlockFromQueue(self): diff --git a/tensorflow/python/lib/io/py_record_reader.cc b/tensorflow/python/lib/io/py_record_reader.cc index a02e3bdba7..e82f1fb223 100644 --- a/tensorflow/python/lib/io/py_record_reader.cc +++ b/tensorflow/python/lib/io/py_record_reader.cc @@ -28,8 +28,8 @@ namespace io { PyRecordReader::PyRecordReader() {} -PyRecordReader* PyRecordReader::New(const string& filename, - uint64 start_offset) { +PyRecordReader* PyRecordReader::New(const string& filename, uint64 start_offset, + const string& compression_type_string) { RandomAccessFile* file; Status s = Env::Default()->NewRandomAccessFile(filename, &file); if (!s.ok()) { @@ -38,7 +38,12 @@ PyRecordReader* PyRecordReader::New(const string& filename, PyRecordReader* reader = new PyRecordReader; reader->offset_ = start_offset; reader->file_ = file; - reader->reader_ = new RecordReader(reader->file_); + + RecordReaderOptions options; + if (compression_type_string == "ZLIB") { + options.compression_type = RecordReaderOptions::ZLIB_COMPRESSION; + } + reader->reader_ = new 
RecordReader(reader->file_, options); return reader; } diff --git a/tensorflow/python/lib/io/py_record_reader.h b/tensorflow/python/lib/io/py_record_reader.h index 5d9c7108b6..a72cf04c3c 100644 --- a/tensorflow/python/lib/io/py_record_reader.h +++ b/tensorflow/python/lib/io/py_record_reader.h @@ -33,7 +33,11 @@ class RecordReader; // by multiple threads. class PyRecordReader { public: - static PyRecordReader* New(const string& filename, uint64 start_offset); + // TODO(vrv): make this take a shared proto to configure + // the compression options. + static PyRecordReader* New(const string& filename, uint64 start_offset, + const string& compression_type_string); + ~PyRecordReader(); // Attempt to get the next record at "current_offset()". If diff --git a/tensorflow/python/lib/io/py_record_writer.cc b/tensorflow/python/lib/io/py_record_writer.cc index 6050ce4c6b..3ff8a3bff2 100644 --- a/tensorflow/python/lib/io/py_record_writer.cc +++ b/tensorflow/python/lib/io/py_record_writer.cc @@ -25,7 +25,8 @@ namespace io { PyRecordWriter::PyRecordWriter() {} -PyRecordWriter* PyRecordWriter::New(const string& filename) { +PyRecordWriter* PyRecordWriter::New(const string& filename, + const string& compression_type_string) { WritableFile* file; Status s = Env::Default()->NewWritableFile(filename, &file); if (!s.ok()) { @@ -33,7 +34,12 @@ PyRecordWriter* PyRecordWriter::New(const string& filename) { } PyRecordWriter* writer = new PyRecordWriter; writer->file_ = file; - writer->writer_ = new RecordWriter(writer->file_); + + RecordWriterOptions options; + if (compression_type_string == "ZLIB") { + options.compression_type = RecordWriterOptions::ZLIB_COMPRESSION; + } + writer->writer_ = new RecordWriter(writer->file_, options); return writer; } diff --git a/tensorflow/python/lib/io/py_record_writer.h b/tensorflow/python/lib/io/py_record_writer.h index e99381f492..86e2b9e56f 100644 --- a/tensorflow/python/lib/io/py_record_writer.h +++ b/tensorflow/python/lib/io/py_record_writer.h @@ 
-33,7 +33,10 @@ class RecordWriter; // by multiple threads. class PyRecordWriter { public: - static PyRecordWriter* New(const string& filename); + // TODO(vrv): make this take a shared proto to configure + // the compression options. + static PyRecordWriter* New(const string& filename, + const string& compression_type_string); ~PyRecordWriter(); bool WriteRecord(tensorflow::StringPiece record); diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py index f792d4d1d8..4694dee5f6 100644 --- a/tensorflow/python/lib/io/tf_record.py +++ b/tensorflow/python/lib/io/tf_record.py @@ -23,11 +23,25 @@ from tensorflow.python import pywrap_tensorflow from tensorflow.python.util import compat -def tf_record_iterator(path): +class TFRecordCompressionType(object): + NONE = 0 + ZLIB = 1 + + +# NOTE(vrv): This will eventually be converted into a proto. to match +# the interface used by the C++ RecordWriter. +class TFRecordOptions(object): + + def __init__(self, compression_type): + self.compression_type = compression_type + + +def tf_record_iterator(path, options=None): """An iterator that read the records from a TFRecords file. Args: path: The path to the TFRecords file. + options: (optional) A TFRecordOptions object. Yields: Strings. @@ -35,7 +49,14 @@ def tf_record_iterator(path): Raises: IOError: If `path` cannot be opened for reading. """ - reader = pywrap_tensorflow.PyRecordReader_New(compat.as_bytes(path), 0) + compression_type_string = "" + if options: + if options.compression_type == TFRecordCompressionType.ZLIB: + compression_type_string = "ZLIB" + + reader = pywrap_tensorflow.PyRecordReader_New( + compat.as_bytes(path), 0, compat.as_bytes(compression_type_string)) + if reader is None: raise IOError("Could not open %s." % path) while reader.GetNext(): @@ -54,16 +75,23 @@ class TFRecordWriter(object): @@close """ # TODO(josh11b): Support appending? 
- def __init__(self, path): + def __init__(self, path, options=None): """Opens file `path` and creates a `TFRecordWriter` writing to it. Args: path: The path to the TFRecords file. + options: (optional) A TFRecordOptions object. Raises: IOError: If `path` cannot be opened for writing. """ - self._writer = pywrap_tensorflow.PyRecordWriter_New(compat.as_bytes(path)) + compression_type_string = "" + if options: + if options.compression_type == TFRecordCompressionType.ZLIB: + compression_type_string = "ZLIB" + + self._writer = pywrap_tensorflow.PyRecordWriter_New( + compat.as_bytes(path), compat.as_bytes(compression_type_string)) if self._writer is None: raise IOError("Could not write to %s." % path) diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py index a5ed7d7a49..0483e0e7aa 100644 --- a/tensorflow/python/ops/io_ops.py +++ b/tensorflow/python/ops/io_ops.py @@ -137,6 +137,7 @@ from tensorflow.python.framework import common_shapes from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.lib.io import python_io from tensorflow.python.ops import gen_io_ops # go/tf-wildcard-import # pylint: disable=wildcard-import @@ -523,13 +524,20 @@ class TFRecordReader(ReaderBase): """ # TODO(josh11b): Support serializing and restoring state. - def __init__(self, name=None): + def __init__(self, name=None, options=None): """Create a TFRecordReader. Args: name: A name for the operation (optional). + options: A TFRecordOptions object (optional). 
""" - rr = gen_io_ops._tf_record_reader(name=name) + compression_type_string = "" + if (options and + options.compression_type == python_io.TFRecordCompressionType.ZLIB): + compression_type_string = "ZLIB" + + rr = gen_io_ops._tf_record_reader(name=name, + compression_type=compression_type_string) super(TFRecordReader, self).__init__(rr) diff --git a/tensorflow/python/summary/impl/event_file_loader.py b/tensorflow/python/summary/impl/event_file_loader.py index c30ed80481..509a371412 100644 --- a/tensorflow/python/summary/impl/event_file_loader.py +++ b/tensorflow/python/summary/impl/event_file_loader.py @@ -35,7 +35,7 @@ class EventFileLoader(object): file_path = resource_loader.readahead_file_path(file_path) logging.debug('Opening a record reader pointing at %s', file_path) self._reader = pywrap_tensorflow.PyRecordReader_New( - compat.as_bytes(file_path), 0) + compat.as_bytes(file_path), 0, compat.as_bytes('')) # Store it for logging purposes. self._file_path = file_path if not self._reader: diff --git a/tensorflow/python/summary/impl/gcs_file_loader.py b/tensorflow/python/summary/impl/gcs_file_loader.py index d820b02d42..b82bb85fc2 100644 --- a/tensorflow/python/summary/impl/gcs_file_loader.py +++ b/tensorflow/python/summary/impl/gcs_file_loader.py @@ -46,7 +46,8 @@ class GCSFileLoader(object): name = temp_file.name logging.debug('Temp file created at %s', name) gcs.CopyContents(self._gcs_path, self._gcs_offset, temp_file) - reader = pywrap_tensorflow.PyRecordReader_New(compat.as_bytes(name), 0) + reader = pywrap_tensorflow.PyRecordReader_New( + compat.as_bytes(name), 0, compat.as_bytes('')) while reader.GetNext(): event = event_pb2.Event() event.ParseFromString(reader.record()) |