about summary refs log tree commit diff homepage
path: root/tensorflow/core/lib/io
diff options
context:
space:
mode:
author Saurabh Saxena <srbs@google.com> 2017-10-18 11:58:24 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-10-18 12:09:48 -0700
commit f5ea388e48a38b935ebd36442f756c8974b7ce3f (patch)
tree 701ee4039719113d837355c6c0091185bc4ed001 /tensorflow/core/lib/io
parent f5d3bf42b892ecfbde2ce9eb45f00b76473c824a (diff)
Implement ZlibInputStream::Tell() by keeping track of the number of bytes
consumed by the reader.

PiperOrigin-RevId: 172634455
Diffstat (limited to 'tensorflow/core/lib/io')
-rw-r--r-- tensorflow/core/lib/io/zlib_buffers_test.cc 172
-rw-r--r-- tensorflow/core/lib/io/zlib_inputstream.cc 8
-rw-r--r-- tensorflow/core/lib/io/zlib_inputstream.h 3
3 files changed, 156 insertions, 27 deletions
diff --git a/tensorflow/core/lib/io/zlib_buffers_test.cc b/tensorflow/core/lib/io/zlib_buffers_test.cc
index 66ee68a916..156c712db8 100644
--- a/tensorflow/core/lib/io/zlib_buffers_test.cc
+++ b/tensorflow/core/lib/io/zlib_buffers_test.cc
@@ -68,25 +68,25 @@ void TestAllCombinations(CompressionOptions input_options,
for (auto input_buf_size : InputBufferSizes()) {
for (auto output_buf_size : OutputBufferSizes()) {
std::unique_ptr<WritableFile> file_writer;
- TF_CHECK_OK(env->NewWritableFile(fname, &file_writer));
+ TF_ASSERT_OK(env->NewWritableFile(fname, &file_writer));
string result;
ZlibOutputBuffer out(file_writer.get(), input_buf_size, output_buf_size,
output_options);
- TF_CHECK_OK(out.Init());
+ TF_ASSERT_OK(out.Init());
- TF_CHECK_OK(out.Append(StringPiece(data)));
- TF_CHECK_OK(out.Close());
- TF_CHECK_OK(file_writer->Flush());
- TF_CHECK_OK(file_writer->Close());
+ TF_ASSERT_OK(out.Append(StringPiece(data)));
+ TF_ASSERT_OK(out.Close());
+ TF_ASSERT_OK(file_writer->Flush());
+ TF_ASSERT_OK(file_writer->Close());
std::unique_ptr<RandomAccessFile> file_reader;
- TF_CHECK_OK(env->NewRandomAccessFile(fname, &file_reader));
+ TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file_reader));
std::unique_ptr<RandomAccessInputStream> input_stream(
new RandomAccessInputStream(file_reader.get()));
ZlibInputStream in(input_stream.get(), input_buf_size, output_buf_size,
input_options);
- TF_EXPECT_OK(in.ReadNBytes(data.size(), &result));
+ TF_ASSERT_OK(in.ReadNBytes(data.size(), &result));
EXPECT_EQ(result, data);
}
}
@@ -118,24 +118,24 @@ void TestMultipleWrites(uint8 input_buf_size, uint8 output_buf_size,
string actual_result;
string expected_result;
- TF_CHECK_OK(env->NewWritableFile(fname, &file_writer));
+ TF_ASSERT_OK(env->NewWritableFile(fname, &file_writer));
ZlibOutputBuffer out(file_writer.get(), input_buf_size, output_buf_size,
output_options);
- TF_CHECK_OK(out.Init());
+ TF_ASSERT_OK(out.Init());
for (int i = 0; i < num_writes; i++) {
- TF_CHECK_OK(out.Append(StringPiece(data)));
+ TF_ASSERT_OK(out.Append(StringPiece(data)));
if (with_flush) {
- TF_CHECK_OK(out.Flush());
+ TF_ASSERT_OK(out.Flush());
}
strings::StrAppend(&expected_result, data);
}
- TF_CHECK_OK(out.Close());
- TF_CHECK_OK(file_writer->Flush());
- TF_CHECK_OK(file_writer->Close());
+ TF_ASSERT_OK(out.Close());
+ TF_ASSERT_OK(file_writer->Flush());
+ TF_ASSERT_OK(file_writer->Close());
std::unique_ptr<RandomAccessFile> file_reader;
- TF_CHECK_OK(env->NewRandomAccessFile(fname, &file_reader));
+ TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file_reader));
std::unique_ptr<RandomAccessInputStream> input_stream(
new RandomAccessInputStream(file_reader.get()));
ZlibInputStream in(input_stream.get(), input_buf_size, output_buf_size,
@@ -143,7 +143,7 @@ void TestMultipleWrites(uint8 input_buf_size, uint8 output_buf_size,
for (int i = 0; i < num_writes; i++) {
string decompressed_output;
- TF_EXPECT_OK(in.ReadNBytes(data.size(), &decompressed_output));
+ TF_ASSERT_OK(in.ReadNBytes(data.size(), &decompressed_output));
strings::StrAppend(&actual_result, decompressed_output);
}
@@ -170,19 +170,19 @@ TEST(ZlibInputStream, FailsToReadIfWindowBitsAreIncompatible) {
string data = GenTestString(10);
std::unique_ptr<WritableFile> file_writer;
- TF_CHECK_OK(env->NewWritableFile(fname, &file_writer));
+ TF_ASSERT_OK(env->NewWritableFile(fname, &file_writer));
string result;
ZlibOutputBuffer out(file_writer.get(), input_buf_size, output_buf_size,
output_options);
- TF_CHECK_OK(out.Init());
+ TF_ASSERT_OK(out.Init());
- TF_CHECK_OK(out.Append(StringPiece(data)));
- TF_CHECK_OK(out.Close());
- TF_CHECK_OK(file_writer->Flush());
- TF_CHECK_OK(file_writer->Close());
+ TF_ASSERT_OK(out.Append(StringPiece(data)));
+ TF_ASSERT_OK(out.Close());
+ TF_ASSERT_OK(file_writer->Flush());
+ TF_ASSERT_OK(file_writer->Close());
std::unique_ptr<RandomAccessFile> file_reader;
- TF_CHECK_OK(env->NewRandomAccessFile(fname, &file_reader));
+ TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file_reader));
std::unique_ptr<RandomAccessInputStream> input_stream(
new RandomAccessInputStream(file_reader.get()));
ZlibInputStream in(input_stream.get(), input_buf_size, output_buf_size,
@@ -192,5 +192,129 @@ TEST(ZlibInputStream, FailsToReadIfWindowBitsAreIncompatible) {
CHECK(read_status.error_message().find("inflate() failed") != string::npos);
}
+void WriteCompressedFile(Env* env, const string& fname, int input_buf_size,
+ int output_buf_size,
+ const CompressionOptions& output_options,
+ const string& data) {
+ std::unique_ptr<WritableFile> file_writer;
+ TF_ASSERT_OK(env->NewWritableFile(fname, &file_writer));
+
+ ZlibOutputBuffer out(file_writer.get(), input_buf_size, output_buf_size,
+ output_options);
+ TF_ASSERT_OK(out.Init());
+
+ TF_ASSERT_OK(out.Append(StringPiece(data)));
+ TF_ASSERT_OK(out.Close());
+ TF_ASSERT_OK(file_writer->Flush());
+ TF_ASSERT_OK(file_writer->Close());
+}
+
+void TestTell(CompressionOptions input_options,
+ CompressionOptions output_options) {
+ Env* env = Env::Default();
+ string fname = testing::TmpDir() + "/zlib_buffers_test";
+ for (auto file_size : NumCopies()) {
+ string data = GenTestString(file_size);
+ for (auto input_buf_size : InputBufferSizes()) {
+ for (auto output_buf_size : OutputBufferSizes()) {
+ // Write the compressed file.
+ WriteCompressedFile(env, fname, input_buf_size, output_buf_size,
+ output_options, data);
+
+ // Boiler-plate to set up ZlibInputStream.
+ std::unique_ptr<RandomAccessFile> file_reader;
+ TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file_reader));
+ std::unique_ptr<RandomAccessInputStream> input_stream(
+ new RandomAccessInputStream(file_reader.get()));
+ ZlibInputStream in(input_stream.get(), input_buf_size, output_buf_size,
+ input_options);
+
+ string first_half(data, 0, data.size() / 2);
+ string bytes_read;
+
+ // Read the first half of the uncompressed file and expect that Tell()
+ // returns half the uncompressed length of the file.
+ TF_ASSERT_OK(in.ReadNBytes(first_half.size(), &bytes_read));
+ EXPECT_EQ(in.Tell(), first_half.size());
+ EXPECT_EQ(bytes_read, first_half);
+
+ // Read the remaining half of the uncompressed file and expect that
+ // Tell() points past the end of file.
+ string second_half;
+ TF_ASSERT_OK(
+ in.ReadNBytes(data.size() - first_half.size(), &second_half));
+ EXPECT_EQ(in.Tell(), data.size());
+ bytes_read.append(second_half);
+
+ // Expect that the file is correctly read.
+ EXPECT_EQ(bytes_read, data);
+ }
+ }
+ }
+}
+
+void TestSkipNBytes(CompressionOptions input_options,
+ CompressionOptions output_options) {
+ Env* env = Env::Default();
+ string fname = testing::TmpDir() + "/zlib_buffers_test";
+ for (auto file_size : NumCopies()) {
+ string data = GenTestString(file_size);
+ for (auto input_buf_size : InputBufferSizes()) {
+ for (auto output_buf_size : OutputBufferSizes()) {
+ // Write the compressed file.
+ WriteCompressedFile(env, fname, input_buf_size, output_buf_size,
+ output_options, data);
+
+ // Boiler-plate to set up ZlibInputStream.
+ std::unique_ptr<RandomAccessFile> file_reader;
+ TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file_reader));
+ std::unique_ptr<RandomAccessInputStream> input_stream(
+ new RandomAccessInputStream(file_reader.get()));
+ ZlibInputStream in(input_stream.get(), input_buf_size, output_buf_size,
+ input_options);
+
+ size_t data_half_size = data.size() / 2;
+ string second_half(data, data_half_size, data.size() - data_half_size);
+
+ // Skip past the first half of the file and expect Tell() returns
+ // correctly.
+ TF_ASSERT_OK(in.SkipNBytes(data_half_size));
+ EXPECT_EQ(in.Tell(), data_half_size);
+
+ // Expect that second half is read correctly and Tell() returns past
+ // end of file after reading complete file.
+ string bytes_read;
+ TF_ASSERT_OK(in.ReadNBytes(second_half.size(), &bytes_read));
+ EXPECT_EQ(bytes_read, second_half);
+ EXPECT_EQ(in.Tell(), data.size());
+ }
+ }
+ }
+}
+
+TEST(ZlibInputStream, TellDefaultOptions) {
+ TestTell(CompressionOptions::DEFAULT(), CompressionOptions::DEFAULT());
+}
+
+TEST(ZlibInputStream, TellRawDeflate) {
+ TestTell(CompressionOptions::RAW(), CompressionOptions::RAW());
+}
+
+TEST(ZlibInputStream, TellGzip) {
+ TestTell(CompressionOptions::GZIP(), CompressionOptions::GZIP());
+}
+
+TEST(ZlibInputStream, SkipNBytesDefaultOptions) {
+ TestSkipNBytes(CompressionOptions::DEFAULT(), CompressionOptions::DEFAULT());
+}
+
+TEST(ZlibInputStream, SkipNBytesRawDeflate) {
+ TestSkipNBytes(CompressionOptions::RAW(), CompressionOptions::RAW());
+}
+
+TEST(ZlibInputStream, SkipNBytesGzip) {
+ TestSkipNBytes(CompressionOptions::GZIP(), CompressionOptions::GZIP());
+}
+
} // namespace io
} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/zlib_inputstream.cc b/tensorflow/core/lib/io/zlib_inputstream.cc
index 4999d5cc90..984fbc2810 100644
--- a/tensorflow/core/lib/io/zlib_inputstream.cc
+++ b/tensorflow/core/lib/io/zlib_inputstream.cc
@@ -32,7 +32,8 @@ ZlibInputStream::ZlibInputStream(
z_stream_input_(new Bytef[input_buffer_capacity_]),
z_stream_output_(new Bytef[output_buffer_capacity_]),
zlib_options_(zlib_options),
- z_stream_(new z_stream) {
+ z_stream_(new z_stream),
+ bytes_read_(0) {
InitZlibBuffer();
}
@@ -45,6 +46,7 @@ ZlibInputStream::~ZlibInputStream() {
Status ZlibInputStream::Reset() {
TF_RETURN_IF_ERROR(input_stream_->Reset());
InitZlibBuffer();
+ bytes_read_ = 0;
return Status::OK();
}
@@ -127,6 +129,7 @@ size_t ZlibInputStream::ReadBytesFromCache(size_t bytes_to_read,
result->append(next_unread_byte_, can_read_bytes);
next_unread_byte_ += can_read_bytes;
}
+ bytes_read_ += can_read_bytes;
return can_read_bytes;
}
@@ -170,8 +173,7 @@ Status ZlibInputStream::ReadNBytes(int64 bytes_to_read, string* result) {
return Status::OK();
}
-// TODO(srbs): Implement this.
-int64 ZlibInputStream::Tell() const { return -1; }
+int64 ZlibInputStream::Tell() const { return bytes_read_; }
Status ZlibInputStream::Inflate() {
int error = inflate(z_stream_.get(), zlib_options_.flush_mode);
diff --git a/tensorflow/core/lib/io/zlib_inputstream.h b/tensorflow/core/lib/io/zlib_inputstream.h
index 8faa7dcb8f..9c7e14441c 100644
--- a/tensorflow/core/lib/io/zlib_inputstream.h
+++ b/tensorflow/core/lib/io/zlib_inputstream.h
@@ -132,6 +132,9 @@ class ZlibInputStream : public InputStreamInterface {
// Returns the size of [next_unread_byte_, z_stream_->next_out)
size_t NumUnreadBytes() const;
+ // Number of *uncompressed* bytes that have been read from this stream.
+ int64 bytes_read_;
+
TF_DISALLOW_COPY_AND_ASSIGN(ZlibInputStream);
};