diff options
author | Saurabh Saxena <srbs@google.com> | 2017-10-18 11:58:24 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-18 12:09:48 -0700 |
commit | f5ea388e48a38b935ebd36442f756c8974b7ce3f (patch) | |
tree | 701ee4039719113d837355c6c0091185bc4ed001 /tensorflow/core/lib/io | |
parent | f5d3bf42b892ecfbde2ce9eb45f00b76473c824a (diff) |
Implement ZlibInputStream::Tell() by keeping track of the number of bytes
consumed by the reader.
PiperOrigin-RevId: 172634455
Diffstat (limited to 'tensorflow/core/lib/io')
-rw-r--r-- | tensorflow/core/lib/io/zlib_buffers_test.cc | 172 | ||||
-rw-r--r-- | tensorflow/core/lib/io/zlib_inputstream.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/lib/io/zlib_inputstream.h | 3 |
3 files changed, 156 insertions, 27 deletions
diff --git a/tensorflow/core/lib/io/zlib_buffers_test.cc b/tensorflow/core/lib/io/zlib_buffers_test.cc index 66ee68a916..156c712db8 100644 --- a/tensorflow/core/lib/io/zlib_buffers_test.cc +++ b/tensorflow/core/lib/io/zlib_buffers_test.cc @@ -68,25 +68,25 @@ void TestAllCombinations(CompressionOptions input_options, for (auto input_buf_size : InputBufferSizes()) { for (auto output_buf_size : OutputBufferSizes()) { std::unique_ptr<WritableFile> file_writer; - TF_CHECK_OK(env->NewWritableFile(fname, &file_writer)); + TF_ASSERT_OK(env->NewWritableFile(fname, &file_writer)); string result; ZlibOutputBuffer out(file_writer.get(), input_buf_size, output_buf_size, output_options); - TF_CHECK_OK(out.Init()); + TF_ASSERT_OK(out.Init()); - TF_CHECK_OK(out.Append(StringPiece(data))); - TF_CHECK_OK(out.Close()); - TF_CHECK_OK(file_writer->Flush()); - TF_CHECK_OK(file_writer->Close()); + TF_ASSERT_OK(out.Append(StringPiece(data))); + TF_ASSERT_OK(out.Close()); + TF_ASSERT_OK(file_writer->Flush()); + TF_ASSERT_OK(file_writer->Close()); std::unique_ptr<RandomAccessFile> file_reader; - TF_CHECK_OK(env->NewRandomAccessFile(fname, &file_reader)); + TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file_reader)); std::unique_ptr<RandomAccessInputStream> input_stream( new RandomAccessInputStream(file_reader.get())); ZlibInputStream in(input_stream.get(), input_buf_size, output_buf_size, input_options); - TF_EXPECT_OK(in.ReadNBytes(data.size(), &result)); + TF_ASSERT_OK(in.ReadNBytes(data.size(), &result)); EXPECT_EQ(result, data); } } @@ -118,24 +118,24 @@ void TestMultipleWrites(uint8 input_buf_size, uint8 output_buf_size, string actual_result; string expected_result; - TF_CHECK_OK(env->NewWritableFile(fname, &file_writer)); + TF_ASSERT_OK(env->NewWritableFile(fname, &file_writer)); ZlibOutputBuffer out(file_writer.get(), input_buf_size, output_buf_size, output_options); - TF_CHECK_OK(out.Init()); + TF_ASSERT_OK(out.Init()); for (int i = 0; i < num_writes; i++) { - TF_CHECK_OK(out.Append(StringPiece(data))); + TF_ASSERT_OK(out.Append(StringPiece(data))); if (with_flush) { - TF_CHECK_OK(out.Flush()); + TF_ASSERT_OK(out.Flush()); } strings::StrAppend(&expected_result, data); } - TF_CHECK_OK(out.Close()); - TF_CHECK_OK(file_writer->Flush()); - TF_CHECK_OK(file_writer->Close()); + TF_ASSERT_OK(out.Close()); + TF_ASSERT_OK(file_writer->Flush()); + TF_ASSERT_OK(file_writer->Close()); std::unique_ptr<RandomAccessFile> file_reader; - TF_CHECK_OK(env->NewRandomAccessFile(fname, &file_reader)); + TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file_reader)); std::unique_ptr<RandomAccessInputStream> input_stream( new RandomAccessInputStream(file_reader.get())); ZlibInputStream in(input_stream.get(), input_buf_size, output_buf_size, @@ -143,7 +143,7 @@ void TestMultipleWrites(uint8 input_buf_size, uint8 output_buf_size, for (int i = 0; i < num_writes; i++) { string decompressed_output; - TF_EXPECT_OK(in.ReadNBytes(data.size(), &decompressed_output)); + TF_ASSERT_OK(in.ReadNBytes(data.size(), &decompressed_output)); strings::StrAppend(&actual_result, decompressed_output); } @@ -170,19 +170,19 @@ TEST(ZlibInputStream, FailsToReadIfWindowBitsAreIncompatible) { string data = GenTestString(10); std::unique_ptr<WritableFile> file_writer; - TF_CHECK_OK(env->NewWritableFile(fname, &file_writer)); + TF_ASSERT_OK(env->NewWritableFile(fname, &file_writer)); string result; ZlibOutputBuffer out(file_writer.get(), input_buf_size, output_buf_size, output_options); - TF_CHECK_OK(out.Init()); + TF_ASSERT_OK(out.Init()); - TF_CHECK_OK(out.Append(StringPiece(data))); - TF_CHECK_OK(out.Close()); - TF_CHECK_OK(file_writer->Flush()); - TF_CHECK_OK(file_writer->Close()); + TF_ASSERT_OK(out.Append(StringPiece(data))); + TF_ASSERT_OK(out.Close()); + TF_ASSERT_OK(file_writer->Flush()); + TF_ASSERT_OK(file_writer->Close()); std::unique_ptr<RandomAccessFile> file_reader; - TF_CHECK_OK(env->NewRandomAccessFile(fname, &file_reader)); + TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file_reader)); std::unique_ptr<RandomAccessInputStream> input_stream( new RandomAccessInputStream(file_reader.get())); ZlibInputStream in(input_stream.get(), input_buf_size, output_buf_size, @@ -192,5 +192,129 @@ TEST(ZlibInputStream, FailsToReadIfWindowBitsAreIncompatible) { CHECK(read_status.error_message().find("inflate() failed") != string::npos); } +void WriteCompressedFile(Env* env, const string& fname, int input_buf_size, + int output_buf_size, + const CompressionOptions& output_options, + const string& data) { + std::unique_ptr<WritableFile> file_writer; + TF_ASSERT_OK(env->NewWritableFile(fname, &file_writer)); + + ZlibOutputBuffer out(file_writer.get(), input_buf_size, output_buf_size, + output_options); + TF_ASSERT_OK(out.Init()); + + TF_ASSERT_OK(out.Append(StringPiece(data))); + TF_ASSERT_OK(out.Close()); + TF_ASSERT_OK(file_writer->Flush()); + TF_ASSERT_OK(file_writer->Close()); +} + +void TestTell(CompressionOptions input_options, + CompressionOptions output_options) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/zlib_buffers_test"; + for (auto file_size : NumCopies()) { + string data = GenTestString(file_size); + for (auto input_buf_size : InputBufferSizes()) { + for (auto output_buf_size : OutputBufferSizes()) { + // Write the compressed file. + WriteCompressedFile(env, fname, input_buf_size, output_buf_size, + output_options, data); + + // Boiler-plate to set up ZlibInputStream. + std::unique_ptr<RandomAccessFile> file_reader; + TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file_reader)); + std::unique_ptr<RandomAccessInputStream> input_stream( + new RandomAccessInputStream(file_reader.get())); + ZlibInputStream in(input_stream.get(), input_buf_size, output_buf_size, + input_options); + + string first_half(data, 0, data.size() / 2); + string bytes_read; + + // Read the first half of the uncompressed file and expect that Tell() + // returns half the uncompressed length of the file. + TF_ASSERT_OK(in.ReadNBytes(first_half.size(), &bytes_read)); + EXPECT_EQ(in.Tell(), first_half.size()); + EXPECT_EQ(bytes_read, first_half); + + // Read the remaining half of the uncompressed file and expect that + // Tell() points past the end of file. + string second_half; + TF_ASSERT_OK( + in.ReadNBytes(data.size() - first_half.size(), &second_half)); + EXPECT_EQ(in.Tell(), data.size()); + bytes_read.append(second_half); + + // Expect that the file is correctly read. + EXPECT_EQ(bytes_read, data); + } + } + } +} + +void TestSkipNBytes(CompressionOptions input_options, + CompressionOptions output_options) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/zlib_buffers_test"; + for (auto file_size : NumCopies()) { + string data = GenTestString(file_size); + for (auto input_buf_size : InputBufferSizes()) { + for (auto output_buf_size : OutputBufferSizes()) { + // Write the compressed file. + WriteCompressedFile(env, fname, input_buf_size, output_buf_size, + output_options, data); + + // Boiler-plate to set up ZlibInputStream. + std::unique_ptr<RandomAccessFile> file_reader; + TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file_reader)); + std::unique_ptr<RandomAccessInputStream> input_stream( + new RandomAccessInputStream(file_reader.get())); + ZlibInputStream in(input_stream.get(), input_buf_size, output_buf_size, + input_options); + + size_t data_half_size = data.size() / 2; + string second_half(data, data_half_size, data.size() - data_half_size); + + // Skip past the first half of the file and expect Tell() returns + // correctly. + TF_ASSERT_OK(in.SkipNBytes(data_half_size)); + EXPECT_EQ(in.Tell(), data_half_size); + + // Expect that second half is read correctly and Tell() returns past + // end of file after reading complete file. + string bytes_read; + TF_ASSERT_OK(in.ReadNBytes(second_half.size(), &bytes_read)); + EXPECT_EQ(bytes_read, second_half); + EXPECT_EQ(in.Tell(), data.size()); + } + } + } +} + +TEST(ZlibInputStream, TellDefaultOptions) { + TestTell(CompressionOptions::DEFAULT(), CompressionOptions::DEFAULT()); +} + +TEST(ZlibInputStream, TellRawDeflate) { + TestTell(CompressionOptions::RAW(), CompressionOptions::RAW()); +} + +TEST(ZlibInputStream, TellGzip) { + TestTell(CompressionOptions::GZIP(), CompressionOptions::GZIP()); +} + +TEST(ZlibInputStream, SkipNBytesDefaultOptions) { + TestSkipNBytes(CompressionOptions::DEFAULT(), CompressionOptions::DEFAULT()); +} + +TEST(ZlibInputStream, SkipNBytesRawDeflate) { + TestSkipNBytes(CompressionOptions::RAW(), CompressionOptions::RAW()); +} + +TEST(ZlibInputStream, SkipNBytesGzip) { + TestSkipNBytes(CompressionOptions::GZIP(), CompressionOptions::GZIP()); +} + } // namespace io } // namespace tensorflow diff --git a/tensorflow/core/lib/io/zlib_inputstream.cc b/tensorflow/core/lib/io/zlib_inputstream.cc index 4999d5cc90..984fbc2810 100644 --- a/tensorflow/core/lib/io/zlib_inputstream.cc +++ b/tensorflow/core/lib/io/zlib_inputstream.cc @@ -32,7 +32,8 @@ ZlibInputStream::ZlibInputStream( z_stream_input_(new Bytef[input_buffer_capacity_]), z_stream_output_(new Bytef[output_buffer_capacity_]), zlib_options_(zlib_options), - z_stream_(new z_stream) { + z_stream_(new z_stream), + bytes_read_(0) { InitZlibBuffer(); } @@ -45,6 +46,7 @@ ZlibInputStream::~ZlibInputStream() { Status ZlibInputStream::Reset() { TF_RETURN_IF_ERROR(input_stream_->Reset()); InitZlibBuffer(); + bytes_read_ = 0; return Status::OK(); } @@ -127,6 +129,7 @@ size_t ZlibInputStream::ReadBytesFromCache(size_t bytes_to_read, result->append(next_unread_byte_, can_read_bytes); next_unread_byte_ += can_read_bytes; } + bytes_read_ += can_read_bytes; return can_read_bytes; } @@ -170,8 +173,7 @@ Status ZlibInputStream::ReadNBytes(int64 bytes_to_read, string* result) { return Status::OK(); } -// TODO(srbs): Implement this. -int64 ZlibInputStream::Tell() const { return -1; } +int64 ZlibInputStream::Tell() const { return bytes_read_; } Status ZlibInputStream::Inflate() { int error = inflate(z_stream_.get(), zlib_options_.flush_mode); diff --git a/tensorflow/core/lib/io/zlib_inputstream.h b/tensorflow/core/lib/io/zlib_inputstream.h index 8faa7dcb8f..9c7e14441c 100644 --- a/tensorflow/core/lib/io/zlib_inputstream.h +++ b/tensorflow/core/lib/io/zlib_inputstream.h @@ -132,6 +132,9 @@ class ZlibInputStream : public InputStreamInterface { // Returns the size of [next_unread_byte_, z_stream_->next_out) size_t NumUnreadBytes() const; + // Number of *uncompressed* bytes that have been read from this stream. + int64 bytes_read_; + TF_DISALLOW_COPY_AND_ASSIGN(ZlibInputStream); }; |