diff options
author | 2016-09-06 14:39:08 -0800 | |
---|---|---|
committer | 2016-09-06 15:47:17 -0700 | |
commit | 680966059e7a5ddc70f1ec4f10e7b19c64c60e4b (patch) | |
tree | 98f0163fb8f625b102d650d6474de379ae79f141 /tensorflow | |
parent | d7bc08fd5f3969773c9717c9c5021c814ca90861 (diff) |
Added a check for output_buffer_size <= 1 for ZlibOutputBuffer. Also adding some tests for Zlib compression reading / writing.
Change: 132370925
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/core/lib/io/record_reader_writer_test.cc | 38 | ||||
-rw-r--r-- | tensorflow/core/lib/io/record_writer.cc | 5 | ||||
-rw-r--r-- | tensorflow/core/lib/io/zlib_buffers_test.cc | 3 | ||||
-rw-r--r-- | tensorflow/core/lib/io/zlib_outputbuffer.cc | 44 | ||||
-rw-r--r-- | tensorflow/core/lib/io/zlib_outputbuffer.h | 6 |
5 files changed, 80 insertions, 16 deletions
diff --git a/tensorflow/core/lib/io/record_reader_writer_test.cc b/tensorflow/core/lib/io/record_reader_writer_test.cc index a44c35d7fd..0a656473e4 100644 --- a/tensorflow/core/lib/io/record_reader_writer_test.cc +++ b/tensorflow/core/lib/io/record_reader_writer_test.cc @@ -67,4 +67,42 @@ TEST(RecordReaderWriterTest, TestBasics) { } } +TEST(RecordReaderWriterTest, TestZlib) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/record_reader_writer_zlib_test"; + + for (auto buf_size : BufferSizes()) { + // Zlib compression needs output buffer size > 1. + if (buf_size == 1) continue; + { + std::unique_ptr<WritableFile> file; + TF_CHECK_OK(env->NewWritableFile(fname, &file)); + + io::RecordWriterOptions options; + options.compression_type = io::RecordWriterOptions::ZLIB_COMPRESSION; + options.zlib_options.output_buffer_size = buf_size; + io::RecordWriter writer(file.get(), options); + writer.WriteRecord("abc"); + writer.WriteRecord("defg"); + TF_CHECK_OK(writer.Flush()); + } + + { + std::unique_ptr<RandomAccessFile> read_file; + // Read it back with the RecordReader. + TF_CHECK_OK(env->NewRandomAccessFile(fname, &read_file)); + io::RecordReaderOptions options; + options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION; + options.zlib_options.input_buffer_size = buf_size; + io::RecordReader reader(read_file.get(), options); + uint64 offset = 0; + string record; + TF_CHECK_OK(reader.ReadRecord(&offset, &record)); + EXPECT_EQ("abc", record); + TF_CHECK_OK(reader.ReadRecord(&offset, &record)); + EXPECT_EQ("defg", record); + } + } +} + } // namespace tensorflow diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc index 25873b83ba..516332d2b7 100644 --- a/tensorflow/core/lib/io/record_writer.cc +++ b/tensorflow/core/lib/io/record_writer.cc @@ -33,6 +33,11 @@ RecordWriter::RecordWriter(WritableFile* dest, zlib_output_buffer_.reset(new ZlibOutputBuffer( dest_, options.zlib_options.input_buffer_size, options.zlib_options.output_buffer_size, options.zlib_options)); + Status s = zlib_output_buffer_->Init(); + if (!s.ok()) { + LOG(FATAL) << "Failed to initialize Zlib inputbuffer. Error: " + << s.ToString(); + } #endif // IS_SLIM_BUILD } else if (options.compression_type == RecordWriterOptions::NONE) { // Nothing to do diff --git a/tensorflow/core/lib/io/zlib_buffers_test.cc b/tensorflow/core/lib/io/zlib_buffers_test.cc index eaaf149759..1290e98ce2 100644 --- a/tensorflow/core/lib/io/zlib_buffers_test.cc +++ b/tensorflow/core/lib/io/zlib_buffers_test.cc @@ -73,6 +73,7 @@ void TestAllCombinations(CompressionOptions input_options, ZlibOutputBuffer out(file_writer.get(), input_buf_size, output_buf_size, output_options); + TF_CHECK_OK(out.Init()); TF_CHECK_OK(out.Write(StringPiece(data))); TF_CHECK_OK(out.Close()); @@ -120,6 +121,7 @@ void TestMultipleWrites(uint8 input_buf_size, uint8 output_buf_size, TF_CHECK_OK(env->NewWritableFile(fname, &file_writer)); ZlibOutputBuffer out(file_writer.get(), input_buf_size, output_buf_size, output_options); + TF_CHECK_OK(out.Init()); for (int i = 0; i < num_writes; i++) { TF_CHECK_OK(out.Write(StringPiece(data))); @@ -172,6 +174,7 @@ TEST(ZlibInputStream, FailsToReadIfWindowBitsAreIncompatible) { string result; ZlibOutputBuffer out(file_writer.get(), input_buf_size, output_buf_size, output_options); + TF_CHECK_OK(out.Init()); TF_CHECK_OK(out.Write(StringPiece(data))); TF_CHECK_OK(out.Close()); diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.cc b/tensorflow/core/lib/io/zlib_outputbuffer.cc index 9493804bcb..bdedfd00e8 100644 --- a/tensorflow/core/lib/io/zlib_outputbuffer.cc +++ b/tensorflow/core/lib/io/zlib_outputbuffer.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/lib/io/zlib_outputbuffer.h" +#include "tensorflow/core/lib/core/errors.h" + namespace tensorflow { namespace io { @@ -25,35 +27,45 @@ ZlibOutputBuffer::ZlibOutputBuffer( const ZlibCompressionOptions& zlib_options) // size of z_stream.next_out buffer : file_(file), + init_status_(), input_buffer_capacity_(input_buffer_bytes), output_buffer_capacity_(output_buffer_bytes), z_stream_input_(new Bytef[input_buffer_bytes]), z_stream_output_(new Bytef[output_buffer_bytes]), zlib_options_(zlib_options), - z_stream_(new z_stream) { + z_stream_(new z_stream) {} + +ZlibOutputBuffer::~ZlibOutputBuffer() { + if (z_stream_.get()) { + LOG(WARNING) << "ZlibOutputBuffer::Close() not called. Possible data loss"; + } +} + +Status ZlibOutputBuffer::Init() { + // Output buffer size should be greater than 1 because deflation needs atleast + // one byte for book keeping etc. + if (output_buffer_capacity_ <= 1) { + return errors::InvalidArgument( + "output_buffer_bytes should be greater than " + "1"); + } memset(z_stream_.get(), 0, sizeof(z_stream)); z_stream_->zalloc = Z_NULL; z_stream_->zfree = Z_NULL; z_stream_->opaque = Z_NULL; int status = - deflateInit2(z_stream_.get(), zlib_options.compression_level, - zlib_options.compression_method, zlib_options.window_bits, - zlib_options.mem_level, zlib_options.compression_strategy); + deflateInit2(z_stream_.get(), zlib_options_.compression_level, + zlib_options_.compression_method, zlib_options_.window_bits, + zlib_options_.mem_level, zlib_options_.compression_strategy); if (status != Z_OK) { - LOG(FATAL) << "deflateInit failed with status " << status; z_stream_.reset(NULL); - } else { - z_stream_->next_in = z_stream_input_.get(); - z_stream_->next_out = z_stream_output_.get(); - z_stream_->avail_in = 0; - z_stream_->avail_out = output_buffer_capacity_; - } -} - -ZlibOutputBuffer::~ZlibOutputBuffer() { - if (z_stream_.get()) { - LOG(WARNING) << "ZlibOutputBuffer::Close() not called. Possible data loss"; + return errors::InvalidArgument("deflateInit failed with status", status); } + z_stream_->next_in = z_stream_input_.get(); + z_stream_->next_out = z_stream_output_.get(); + z_stream_->avail_in = 0; + z_stream_->avail_out = output_buffer_capacity_; + return Status::OK(); } int32 ZlibOutputBuffer::AvailableInputSpace() const { diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.h b/tensorflow/core/lib/io/zlib_outputbuffer.h index 08455b63b5..a53c40b8fb 100644 --- a/tensorflow/core/lib/io/zlib_outputbuffer.h +++ b/tensorflow/core/lib/io/zlib_outputbuffer.h @@ -45,6 +45,7 @@ class ZlibOutputBuffer { // 2. the deflated output // with sizes `input_buffer_bytes` and `output_buffer_bytes` respectively. // Does not take ownership of `file`. + // output_buffer_bytes should be greater than 1. ZlibOutputBuffer( WritableFile* file, int32 input_buffer_bytes, // size of z_stream.next_in buffer @@ -53,6 +54,10 @@ class ZlibOutputBuffer { ~ZlibOutputBuffer(); + // Initializes some state necessary for the output buffer. This call is + // required before any other operation on the buffer. + Status Init(); + // Adds `data` to the compression pipeline. // // The input data is buffered in `z_stream_input_` and is compressed in bulk @@ -78,6 +83,7 @@ class ZlibOutputBuffer { private: WritableFile* file_; // Not owned + Status init_status_; size_t input_buffer_capacity_; size_t output_buffer_capacity_; |