author     A. Unique TensorFlower <gardener@tensorflow.org>   2016-06-16 14:40:17 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>    2016-06-16 15:48:05 -0700
commit     90bad51f97cc32dc9bc8e2ffb3e81b65cec37dfb (patch)
tree       22621e351e2a91323e152f17039162872d07798e
parent     598867b6a207be402cf3555697a212825b81a882 (diff)
Add support to TFRecordReader and TFRecordWriter to read from and write to zlib-compressed files.
Change: 125110928
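
Illustrative usage sketch (not part of this change's diff), assuming the Python APIs added below are exported under tf.python_io as exercised in the updated reader_ops_test.py; the file path is a hypothetical example:

import tensorflow as tf

# Request zlib compression via the new TFRecordOptions object.
options = tf.python_io.TFRecordOptions(
    compression_type=tf.python_io.TFRecordCompressionType.ZLIB)

# Write a few records to a zlib-compressed TFRecords file.
path = "/tmp/zlib_example.tfrecord"  # hypothetical path
writer = tf.python_io.TFRecordWriter(path, options=options)
for i in range(3):
  writer.write(tf.compat.as_bytes("record %d" % i))
writer.close()

# Read the records back eagerly, passing the same options.
for record in tf.python_io.tf_record_iterator(path, options):
  print(record)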
-rw-r--r--  tensorflow/core/BUILD                               |  3
-rw-r--r--  tensorflow/core/kernels/tf_record_reader_op.cc      | 23
-rw-r--r--  tensorflow/core/lib/io/record_reader.cc             | 83
-rw-r--r--  tensorflow/core/lib/io/record_reader.h              | 22
-rw-r--r--  tensorflow/core/lib/io/record_reader_writer_test.cc | 73
-rw-r--r--  tensorflow/core/lib/io/record_writer.cc             | 43
-rw-r--r--  tensorflow/core/lib/io/record_writer.h              | 27
-rw-r--r--  tensorflow/core/lib/io/zlib_compression_options.h   | 13
-rw-r--r--  tensorflow/core/lib/io/zlib_inputbuffer.cc          |  5
-rw-r--r--  tensorflow/core/lib/io/zlib_inputbuffer.h           |  7
-rw-r--r--  tensorflow/core/lib/io/zlib_outputbuffer.cc         |  3
-rw-r--r--  tensorflow/core/lib/io/zlib_outputbuffer.h          |  7
-rw-r--r--  tensorflow/core/ops/io_ops.cc                       |  1
-rw-r--r--  tensorflow/core/util/events_writer.cc               |  7
-rw-r--r--  tensorflow/python/kernel_tests/reader_ops_test.py   | 79
-rw-r--r--  tensorflow/python/lib/io/py_record_reader.cc        | 11
-rw-r--r--  tensorflow/python/lib/io/py_record_reader.h         |  6
-rw-r--r--  tensorflow/python/lib/io/py_record_writer.cc        | 10
-rw-r--r--  tensorflow/python/lib/io/py_record_writer.h         |  5
-rw-r--r--  tensorflow/python/lib/io/tf_record.py               | 36
-rw-r--r--  tensorflow/python/ops/io_ops.py                     | 12
-rw-r--r--  tensorflow/python/summary/impl/event_file_loader.py |  2
-rw-r--r--  tensorflow/python/summary/impl/gcs_file_loader.py   |  3
23 files changed, 422 insertions, 59 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index f1408d44d6..77db4004c7 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -633,7 +633,6 @@ filegroup(
"**/*testlib*",
"**/*main.cc",
"graph/dot.*",
- "lib/io/zlib*",
"lib/jpeg/**/*",
"lib/png/**/*",
"util/checkpoint_reader.*",
@@ -688,6 +687,7 @@ cc_library(
name = "android_tensorflow_lib",
srcs = if_android([":android_op_registrations_and_gradients"]),
copts = tf_copts(),
+ linkopts = ["-lz"],
tags = [
"manual",
"notap",
@@ -1270,6 +1270,7 @@ tf_cc_tests(
"lib/io/inputbuffer_test.cc",
"lib/io/match_test.cc",
"lib/io/path_test.cc",
+ "lib/io/record_reader_writer_test.cc",
"lib/io/recordio_test.cc",
"lib/io/table_test.cc",
"lib/io/zlib_buffers_test.cc",
diff --git a/tensorflow/core/kernels/tf_record_reader_op.cc b/tensorflow/core/kernels/tf_record_reader_op.cc
index 9604d44567..3cbb334141 100644
--- a/tensorflow/core/kernels/tf_record_reader_op.cc
+++ b/tensorflow/core/kernels/tf_record_reader_op.cc
@@ -27,17 +27,25 @@ namespace tensorflow {
class TFRecordReader : public ReaderBase {
public:
- TFRecordReader(const string& node_name, Env* env)
+ TFRecordReader(const string& node_name, const string& compression_type,
+ Env* env)
: ReaderBase(strings::StrCat("TFRecordReader '", node_name, "'")),
env_(env),
- offset_(0) {}
+ offset_(0),
+ compression_type_(compression_type) {}
Status OnWorkStartedLocked() override {
offset_ = 0;
RandomAccessFile* file = nullptr;
TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file));
file_.reset(file);
- reader_.reset(new io::RecordReader(file));
+
+ io::RecordReaderOptions options;
+ if (compression_type_ == "ZLIB") {
+ options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION;
+ }
+
+ reader_.reset(new io::RecordReader(file, options));
return Status::OK();
}
@@ -74,6 +82,7 @@ class TFRecordReader : public ReaderBase {
uint64 offset_;
std::unique_ptr<RandomAccessFile> file_;
std::unique_ptr<io::RecordReader> reader_;
+ string compression_type_ = "";
};
class TFRecordReaderOp : public ReaderOpKernel {
@@ -81,7 +90,13 @@ class TFRecordReaderOp : public ReaderOpKernel {
explicit TFRecordReaderOp(OpKernelConstruction* context)
: ReaderOpKernel(context) {
Env* env = context->env();
- SetReaderFactory([this, env]() { return new TFRecordReader(name(), env); });
+
+ string compression_type;
+ context->GetAttr("compression_type", &compression_type);
+
+ SetReaderFactory([this, compression_type, env]() {
+ return new TFRecordReader(name(), compression_type, env);
+ });
}
};
diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc
index 76011430de..eb194a14d4 100644
--- a/tensorflow/core/lib/io/record_reader.cc
+++ b/tensorflow/core/lib/io/record_reader.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/record_reader.h"
#include <limits.h>
+
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/crc32c.h"
@@ -24,38 +25,76 @@ limitations under the License.
namespace tensorflow {
namespace io {
-RecordReader::RecordReader(RandomAccessFile* file) : src_(file) {}
+RecordReader::RecordReader(RandomAccessFile* file,
+ const RecordReaderOptions& options)
+ : src_(file), options_(options) {
+ if (options.compression_type == RecordReaderOptions::ZLIB_COMPRESSION) {
+ zlib_input_buffer_.reset(new ZlibInputBuffer(
+ src_, options.zlib_options.input_buffer_size,
+ options.zlib_options.output_buffer_size, options.zlib_options));
+ } else if (options.compression_type == RecordReaderOptions::NONE) {
+ // Nothing to do.
+ } else {
+ LOG(FATAL) << "Unspecified compression type :" << options.compression_type;
+ }
+}
RecordReader::~RecordReader() {}
// Read n+4 bytes from file, verify that checksum of first n bytes is
// stored in the last 4 bytes and store the first n bytes in *result.
// May use *storage as backing store.
-static Status ReadChecksummed(RandomAccessFile* file, uint64 offset, size_t n,
- StringPiece* result, string* storage) {
+Status RecordReader::ReadChecksummed(uint64 offset, size_t n,
+ StringPiece* result, string* storage) {
if (n >= SIZE_MAX - sizeof(uint32)) {
return errors::DataLoss("record size too large");
}
const size_t expected = n + sizeof(uint32);
storage->resize(expected);
- StringPiece data;
- Status s = file->Read(offset, expected, &data, &(*storage)[0]);
- if (!s.ok()) {
- return s;
- }
- if (data.size() != expected) {
- if (data.size() == 0) {
- return errors::OutOfRange("eof");
- } else {
- return errors::DataLoss("truncated record at ", offset);
+
+ if (zlib_input_buffer_) {
+ // If we have a zlib compressed buffer, we assume that the
+ // file is being read sequentially, and we use the underlying
+ // implementation to read the data.
+ //
+ // No checks are done to validate that the file is being read
+ // sequentially. At some point the zlib input buffer may support
+ // seeking, possibly inefficiently.
+ TF_RETURN_IF_ERROR(zlib_input_buffer_->ReadNBytes(expected, storage));
+
+ if (storage->size() != expected) {
+ if (storage->size() == 0) {
+ return errors::OutOfRange("eof");
+ } else {
+ return errors::DataLoss("truncated record at ", offset);
+ }
}
+
+ uint32 masked_crc = core::DecodeFixed32(storage->data() + n);
+ if (crc32c::Unmask(masked_crc) != crc32c::Value(storage->data(), n)) {
+ return errors::DataLoss("corrupted record at ", offset);
+ }
+ *result = StringPiece(storage->data(), n);
+ } else {
+ // This version supports reading from arbitrary offsets
+ // since we are accessing the random access file directly.
+ StringPiece data;
+ TF_RETURN_IF_ERROR(src_->Read(offset, expected, &data, &(*storage)[0]));
+ if (data.size() != expected) {
+ if (data.size() == 0) {
+ return errors::OutOfRange("eof");
+ } else {
+ return errors::DataLoss("truncated record at ", offset);
+ }
+ }
+ uint32 masked_crc = core::DecodeFixed32(data.data() + n);
+ if (crc32c::Unmask(masked_crc) != crc32c::Value(data.data(), n)) {
+ return errors::DataLoss("corrupted record at ", offset);
+ }
+ *result = StringPiece(data.data(), n);
}
- uint32 masked_crc = core::DecodeFixed32(data.data() + n);
- if (crc32c::Unmask(masked_crc) != crc32c::Value(data.data(), n)) {
- return errors::DataLoss("corrupted record at ", offset);
- }
- *result = StringPiece(data.data(), n);
+
return Status::OK();
}
@@ -63,9 +102,9 @@ Status RecordReader::ReadRecord(uint64* offset, string* record) {
static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
static const size_t kFooterSize = sizeof(uint32);
- // Read length
+ // Read header data.
StringPiece lbuf;
- Status s = ReadChecksummed(src_, *offset, sizeof(uint64), &lbuf, record);
+ Status s = ReadChecksummed(*offset, sizeof(uint64), &lbuf, record);
if (!s.ok()) {
return s;
}
@@ -73,19 +112,21 @@ Status RecordReader::ReadRecord(uint64* offset, string* record) {
// Read data
StringPiece data;
- s = ReadChecksummed(src_, *offset + kHeaderSize, length, &data, record);
+ s = ReadChecksummed(*offset + kHeaderSize, length, &data, record);
if (!s.ok()) {
if (errors::IsOutOfRange(s)) {
s = errors::DataLoss("truncated record at ", *offset);
}
return s;
}
+
if (record->data() != data.data()) {
// RandomAccessFile placed the data in some other location.
memmove(&(*record)[0], data.data(), data.size());
}
record->resize(data.size());
+
*offset += kHeaderSize + length + kFooterSize;
return Status::OK();
}
diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h
index f96d24aa31..b4c56451be 100644
--- a/tensorflow/core/lib/io/record_reader.h
+++ b/tensorflow/core/lib/io/record_reader.h
@@ -18,6 +18,9 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/io/inputbuffer.h"
+#include "tensorflow/core/lib/io/zlib_compression_options.h"
+#include "tensorflow/core/lib/io/zlib_inputbuffer.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -27,13 +30,23 @@ class RandomAccessFile;
namespace io {
+class RecordReaderOptions {
+ public:
+ enum CompressionType { NONE = 0, ZLIB_COMPRESSION = 1 };
+ CompressionType compression_type = NONE;
+
+ // Options specific to zlib compression.
+ ZlibCompressionOptions zlib_options;
+};
+
class RecordReader {
public:
// Create a reader that will return log records from "*file".
// "*file" must remain live while this Reader is in use.
- explicit RecordReader(RandomAccessFile* file);
+ RecordReader(RandomAccessFile* file,
+ const RecordReaderOptions& options = RecordReaderOptions());
- ~RecordReader();
+ virtual ~RecordReader();
// Read the record at "*offset" into *record and update *offset to
// point to the offset of the next record. Returns OK on success,
@@ -41,7 +54,12 @@ class RecordReader {
Status ReadRecord(uint64* offset, string* record);
private:
+ Status ReadChecksummed(uint64 offset, size_t n, StringPiece* result,
+ string* storage);
+
RandomAccessFile* src_;
+ RecordReaderOptions options_;
+ std::unique_ptr<ZlibInputBuffer> zlib_input_buffer_;
TF_DISALLOW_COPY_AND_ASSIGN(RecordReader);
};
diff --git a/tensorflow/core/lib/io/record_reader_writer_test.cc b/tensorflow/core/lib/io/record_reader_writer_test.cc
new file mode 100644
index 0000000000..ee8a7fa91a
--- /dev/null
+++ b/tensorflow/core/lib/io/record_reader_writer_test.cc
@@ -0,0 +1,73 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/io/record_reader.h"
+#include "tensorflow/core/lib/io/record_writer.h"
+
+#include <vector>
+#include "tensorflow/core/platform/env.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+static std::vector<int> BufferSizes() {
+ return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 65536};
+}
+
+TEST(RecordReaderWriterTest, TestBasics) {
+ Env* env = Env::Default();
+ string fname = testing::TmpDir() + "/record_reader_writer_test";
+
+ for (auto buf_size : BufferSizes()) {
+ WritableFile* file;
+ TF_CHECK_OK(env->NewWritableFile(fname, &file));
+
+ {
+ io::RecordWriterOptions options;
+ options.zlib_options.output_buffer_size = buf_size;
+ io::RecordWriter writer(file, options);
+ writer.WriteRecord("abc");
+ writer.WriteRecord("defg");
+ TF_CHECK_OK(writer.Flush());
+ }
+ delete file;
+
+ RandomAccessFile* read_file;
+ {
+ // Read it back with the RecordReader.
+ TF_CHECK_OK(env->NewRandomAccessFile(fname, &read_file));
+ io::RecordReaderOptions options;
+ options.zlib_options.input_buffer_size = buf_size;
+ io::RecordReader reader(read_file, options);
+ uint64 offset = 0;
+ string record;
+ TF_CHECK_OK(reader.ReadRecord(&offset, &record));
+ EXPECT_EQ("abc", record);
+ TF_CHECK_OK(reader.ReadRecord(&offset, &record));
+ EXPECT_EQ("defg", record);
+ }
+
+ delete read_file;
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc
index e42f1256a4..7993f6ca20 100644
--- a/tensorflow/core/lib/io/record_writer.cc
+++ b/tensorflow/core/lib/io/record_writer.cc
@@ -22,9 +22,28 @@ limitations under the License.
namespace tensorflow {
namespace io {
-RecordWriter::RecordWriter(WritableFile* dest) : dest_(dest) {}
+RecordWriter::RecordWriter(WritableFile* dest,
+ const RecordWriterOptions& options)
+ : dest_(dest), options_(options) {
+ if (options.compression_type == RecordWriterOptions::ZLIB_COMPRESSION) {
+ zlib_output_buffer_.reset(new ZlibOutputBuffer(
+ dest_, options.zlib_options.input_buffer_size,
+ options.zlib_options.output_buffer_size, options.zlib_options));
+ } else if (options.compression_type == RecordWriterOptions::NONE) {
+ // Nothing to do
+ } else {
+ LOG(FATAL) << "Unspecified compression type :" << options.compression_type;
+ }
+}
-RecordWriter::~RecordWriter() {}
+RecordWriter::~RecordWriter() {
+ if (zlib_output_buffer_) {
+ Status s = zlib_output_buffer_->Close();
+ if (!s.ok()) {
+ LOG(ERROR) << "Could not finish writing file: " << s;
+ }
+ }
+}
static uint32 MaskedCrc(const char* data, size_t n) {
return crc32c::Mask(crc32c::Value(data, n));
@@ -40,17 +59,19 @@ Status RecordWriter::WriteRecord(StringPiece data) {
core::EncodeFixed64(header + 0, data.size());
core::EncodeFixed32(header + sizeof(uint64),
MaskedCrc(header, sizeof(uint64)));
- Status s = dest_->Append(StringPiece(header, sizeof(header)));
- if (!s.ok()) {
- return s;
- }
- s = dest_->Append(data);
- if (!s.ok()) {
- return s;
- }
char footer[sizeof(uint32)];
core::EncodeFixed32(footer, MaskedCrc(data.data(), data.size()));
- return dest_->Append(StringPiece(footer, sizeof(footer)));
+
+ if (zlib_output_buffer_) {
+ TF_RETURN_IF_ERROR(
+ zlib_output_buffer_->Write(StringPiece(header, sizeof(header))));
+ TF_RETURN_IF_ERROR(zlib_output_buffer_->Write(data));
+ return zlib_output_buffer_->Write(StringPiece(footer, sizeof(footer)));
+ } else {
+ TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
+ TF_RETURN_IF_ERROR(dest_->Append(data));
+ return dest_->Append(StringPiece(footer, sizeof(footer)));
+ }
}
} // namespace io
diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h
index 93b2a08714..2344df3b25 100644
--- a/tensorflow/core/lib/io/record_writer.h
+++ b/tensorflow/core/lib/io/record_writer.h
@@ -18,6 +18,8 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/io/zlib_compression_options.h"
+#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -27,19 +29,42 @@ class WritableFile;
namespace io {
+class RecordWriterOptions {
+ public:
+ enum CompressionType { NONE = 0, ZLIB_COMPRESSION = 1 };
+ CompressionType compression_type = NONE;
+
+ // Options specific to zlib compression.
+ ZlibCompressionOptions zlib_options;
+};
+
class RecordWriter {
public:
// Create a writer that will append data to "*dest".
// "*dest" must be initially empty.
// "*dest" must remain live while this Writer is in use.
- explicit RecordWriter(WritableFile* dest);
+ RecordWriter(WritableFile* dest,
+ const RecordWriterOptions& options = RecordWriterOptions());
~RecordWriter();
Status WriteRecord(StringPiece slice);
+ // Flushes any buffered data held by underlying containers of the
+ // RecordWriter to the WritableFile. Does *not* flush the
+ // WritableFile.
+ Status Flush() {
+ if (zlib_output_buffer_) {
+ return zlib_output_buffer_->Flush();
+ }
+
+ return Status::OK();
+ }
+
private:
WritableFile* const dest_;
+ RecordWriterOptions options_;
+ std::unique_ptr<ZlibOutputBuffer> zlib_output_buffer_;
TF_DISALLOW_COPY_AND_ASSIGN(RecordWriter);
};
diff --git a/tensorflow/core/lib/io/zlib_compression_options.h b/tensorflow/core/lib/io/zlib_compression_options.h
index 49afea4e36..95af0ab9c9 100644
--- a/tensorflow/core/lib/io/zlib_compression_options.h
+++ b/tensorflow/core/lib/io/zlib_compression_options.h
@@ -16,7 +16,13 @@ limitations under the License.
#ifndef TENSORFLOW_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_
#define TENSORFLOW_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_
+// TODO(srbs|vrv): Move to a platform/zlib.h file to centralize all
+// platform-specific includes
+#ifdef __ANDROID__
+#include "zlib.h"
+#else
#include <zlib.h>
+#endif // __ANDROID__
namespace tensorflow {
namespace io {
@@ -28,6 +34,13 @@ class ZlibCompressionOptions {
int8 flush_mode = Z_NO_FLUSH;
+ // Size of the buffer used for caching the data read from source file.
+ int64 input_buffer_size = 256 << 10;
+
+ // Size of the sink buffer where the compressed/decompressed data produced by
+ // zlib is cached.
+ int64 output_buffer_size = 256 << 10;
+
// The window_bits parameter is the base two logarithm of the window size
// (the size of the history buffer). Larger values of buffer size result in
// better compression at the expense of memory usage.
diff --git a/tensorflow/core/lib/io/zlib_inputbuffer.cc b/tensorflow/core/lib/io/zlib_inputbuffer.cc
index 5f20fba0e6..b5224168f1 100644
--- a/tensorflow/core/lib/io/zlib_inputbuffer.cc
+++ b/tensorflow/core/lib/io/zlib_inputbuffer.cc
@@ -128,6 +128,7 @@ size_t ZlibInputBuffer::NumUnreadBytes() const {
}
Status ZlibInputBuffer::ReadNBytes(int64 bytes_to_read, string* result) {
+ result->clear();
// Read as many bytes as possible from cache.
bytes_to_read -= ReadBytesFromCache(bytes_to_read, result);
@@ -163,8 +164,8 @@ Status ZlibInputBuffer::ReadNBytes(int64 bytes_to_read, string* result) {
Status ZlibInputBuffer::Inflate() {
int error = inflate(z_stream_.get(), zlib_options_.flush_mode);
if (error != Z_OK && error != Z_FINISH) {
- string error_string = strings::StrCat("inflate() failed with error ",
- std::to_string(error).c_str());
+ string error_string =
+ strings::StrCat("inflate() failed with error ", error);
if (z_stream_->msg != NULL) {
strings::StrAppend(&error_string, ": ", z_stream_->msg);
}
diff --git a/tensorflow/core/lib/io/zlib_inputbuffer.h b/tensorflow/core/lib/io/zlib_inputbuffer.h
index e21070134a..008d51876a 100644
--- a/tensorflow/core/lib/io/zlib_inputbuffer.h
+++ b/tensorflow/core/lib/io/zlib_inputbuffer.h
@@ -22,7 +22,14 @@ limitations under the License.
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
+
+// TODO(srbs|vrv): Move to a platform/zlib.h file to centralize all
+// platform-specific includes
+#ifdef __ANDROID__
+#include "zlib.h"
+#else
#include <zlib.h>
+#endif // __ANDROID__
namespace tensorflow {
namespace io {
diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.cc b/tensorflow/core/lib/io/zlib_outputbuffer.cc
index 6db9423c65..9493804bcb 100644
--- a/tensorflow/core/lib/io/zlib_outputbuffer.cc
+++ b/tensorflow/core/lib/io/zlib_outputbuffer.cc
@@ -200,8 +200,7 @@ Status ZlibOutputBuffer::Deflate(int flush) {
(error == Z_STREAM_END && flush == Z_FINISH)) {
return Status::OK();
}
- string error_string = strings::StrCat("deflate() failed with error ",
- std::to_string(error).c_str());
+ string error_string = strings::StrCat("deflate() failed with error ", error);
if (z_stream_->msg != NULL) {
strings::StrAppend(&error_string, ": ", z_stream_->msg);
}
diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.h b/tensorflow/core/lib/io/zlib_outputbuffer.h
index a5c6b16bd9..08455b63b5 100644
--- a/tensorflow/core/lib/io/zlib_outputbuffer.h
+++ b/tensorflow/core/lib/io/zlib_outputbuffer.h
@@ -22,7 +22,14 @@ limitations under the License.
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
+
+// TODO(srbs|vrv): Move to a platform/zlib.h file to centralize all
+// platform-specific includes.
+#ifdef __ANDROID__
+#include "zlib.h"
+#else
#include <zlib.h>
+#endif // __ANDROID__
namespace tensorflow {
namespace io {
diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc
index 112a5278bf..10a0063299 100644
--- a/tensorflow/core/ops/io_ops.cc
+++ b/tensorflow/core/ops/io_ops.cc
@@ -217,6 +217,7 @@ REGISTER_OP("TFRecordReader")
.Output("reader_handle: Ref(string)")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
+ .Attr("compression_type: string = ''")
.SetIsStateful()
.Doc(R"doc(
A Reader that outputs the records from a TensorFlow Records file.
diff --git a/tensorflow/core/util/events_writer.cc b/tensorflow/core/util/events_writer.cc
index 21dca84f9c..c1f359c491 100644
--- a/tensorflow/core/util/events_writer.cc
+++ b/tensorflow/core/util/events_writer.cc
@@ -113,6 +113,13 @@ void EventsWriter::WriteEvent(const Event& event) {
bool EventsWriter::Flush() {
if (num_outstanding_events_ == 0) return true;
CHECK(recordio_file_.get() != NULL) << "Unexpected NULL file";
+
+ if (!recordio_writer_->Flush().ok()) {
+ LOG(ERROR) << "Failed to flush " << num_outstanding_events_ << " events to "
+ << filename_;
+ return false;
+ }
+
// The FileHasDisappeared() condition is necessary because
// recordio_writer_->Sync() can return true even if the underlying
// file has been deleted. EventWriter.FileDeletionBeforeWriting
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py
index a335344c83..8bcec44a61 100644
--- a/tensorflow/python/kernel_tests/reader_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_ops_test.py
@@ -410,6 +410,85 @@ class TFRecordReaderTest(tf.test.TestCase):
self.assertEqual(self._num_files * self._num_records, num_v)
+class TFRecordWriterZlibTest(tf.test.TestCase):
+
+ def setUp(self):
+ super(TFRecordWriterZlibTest, self).setUp()
+ self._num_files = 2
+ self._num_records = 7
+
+ def _Record(self, f, r):
+ return tf.compat.as_bytes("Record %d of file %d" % (r, f))
+
+ def _CreateFiles(self):
+ filenames = []
+ for i in range(self._num_files):
+ fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
+ filenames.append(fn)
+ options = tf.python_io.TFRecordOptions(
+ compression_type=tf.python_io.TFRecordCompressionType.ZLIB)
+ writer = tf.python_io.TFRecordWriter(fn, options=options)
+ for j in range(self._num_records):
+ writer.write(self._Record(i, j))
+ writer.close()
+ del writer
+
+ return filenames
+
+ def testOneEpoch(self):
+ files = self._CreateFiles()
+ with self.test_session() as sess:
+ options = tf.python_io.TFRecordOptions(
+ compression_type=tf.python_io.TFRecordCompressionType.ZLIB)
+ reader = tf.TFRecordReader(name="test_reader", options=options)
+ queue = tf.FIFOQueue(99, [tf.string], shapes=())
+ key, value = reader.read(queue)
+
+ queue.enqueue_many([files]).run()
+ queue.close().run()
+ for i in range(self._num_files):
+ for j in range(self._num_records):
+ k, v = sess.run([key, value])
+ self.assertTrue(tf.compat.as_text(k).startswith("%s:" % files[i]))
+ self.assertAllEqual(self._Record(i, j), v)
+
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ k, v = sess.run([key, value])
+
+
+class TFRecordIteratorTest(tf.test.TestCase):
+
+ def setUp(self):
+ super(TFRecordIteratorTest, self).setUp()
+ self._num_records = 7
+
+ def _Record(self, r):
+ return tf.compat.as_bytes("Record %d" % r)
+
+ def _CreateFile(self):
+ fn = os.path.join(self.get_temp_dir(), "tf_record.txt")
+ options = tf.python_io.TFRecordOptions(
+ compression_type=tf.python_io.TFRecordCompressionType.ZLIB)
+ writer = tf.python_io.TFRecordWriter(fn, options=options)
+ for i in range(self._num_records):
+ writer.write(self._Record(i))
+ writer.close()
+ del writer
+ return fn
+
+ def testIterator(self):
+ fn = self._CreateFile()
+ options = tf.python_io.TFRecordOptions(
+ compression_type=tf.python_io.TFRecordCompressionType.ZLIB)
+ reader = tf.python_io.tf_record_iterator(fn, options)
+ for i in range(self._num_records):
+ record = next(reader)
+ self.assertAllEqual(self._Record(i), record)
+ with self.assertRaises(StopIteration):
+ record = next(reader)
+
+
class AsyncReaderTest(tf.test.TestCase):
def testNoDeadlockFromQueue(self):
diff --git a/tensorflow/python/lib/io/py_record_reader.cc b/tensorflow/python/lib/io/py_record_reader.cc
index a02e3bdba7..e82f1fb223 100644
--- a/tensorflow/python/lib/io/py_record_reader.cc
+++ b/tensorflow/python/lib/io/py_record_reader.cc
@@ -28,8 +28,8 @@ namespace io {
PyRecordReader::PyRecordReader() {}
-PyRecordReader* PyRecordReader::New(const string& filename,
- uint64 start_offset) {
+PyRecordReader* PyRecordReader::New(const string& filename, uint64 start_offset,
+ const string& compression_type_string) {
RandomAccessFile* file;
Status s = Env::Default()->NewRandomAccessFile(filename, &file);
if (!s.ok()) {
@@ -38,7 +38,12 @@ PyRecordReader* PyRecordReader::New(const string& filename,
PyRecordReader* reader = new PyRecordReader;
reader->offset_ = start_offset;
reader->file_ = file;
- reader->reader_ = new RecordReader(reader->file_);
+
+ RecordReaderOptions options;
+ if (compression_type_string == "ZLIB") {
+ options.compression_type = RecordReaderOptions::ZLIB_COMPRESSION;
+ }
+ reader->reader_ = new RecordReader(reader->file_, options);
return reader;
}
diff --git a/tensorflow/python/lib/io/py_record_reader.h b/tensorflow/python/lib/io/py_record_reader.h
index 5d9c7108b6..a72cf04c3c 100644
--- a/tensorflow/python/lib/io/py_record_reader.h
+++ b/tensorflow/python/lib/io/py_record_reader.h
@@ -33,7 +33,11 @@ class RecordReader;
// by multiple threads.
class PyRecordReader {
public:
- static PyRecordReader* New(const string& filename, uint64 start_offset);
+ // TODO(vrv): make this take a shared proto to configure
+ // the compression options.
+ static PyRecordReader* New(const string& filename, uint64 start_offset,
+ const string& compression_type_string);
+
~PyRecordReader();
// Attempt to get the next record at "current_offset()". If
diff --git a/tensorflow/python/lib/io/py_record_writer.cc b/tensorflow/python/lib/io/py_record_writer.cc
index 6050ce4c6b..3ff8a3bff2 100644
--- a/tensorflow/python/lib/io/py_record_writer.cc
+++ b/tensorflow/python/lib/io/py_record_writer.cc
@@ -25,7 +25,8 @@ namespace io {
PyRecordWriter::PyRecordWriter() {}
-PyRecordWriter* PyRecordWriter::New(const string& filename) {
+PyRecordWriter* PyRecordWriter::New(const string& filename,
+ const string& compression_type_string) {
WritableFile* file;
Status s = Env::Default()->NewWritableFile(filename, &file);
if (!s.ok()) {
@@ -33,7 +34,12 @@ PyRecordWriter* PyRecordWriter::New(const string& filename) {
}
PyRecordWriter* writer = new PyRecordWriter;
writer->file_ = file;
- writer->writer_ = new RecordWriter(writer->file_);
+
+ RecordWriterOptions options;
+ if (compression_type_string == "ZLIB") {
+ options.compression_type = RecordWriterOptions::ZLIB_COMPRESSION;
+ }
+ writer->writer_ = new RecordWriter(writer->file_, options);
return writer;
}
diff --git a/tensorflow/python/lib/io/py_record_writer.h b/tensorflow/python/lib/io/py_record_writer.h
index e99381f492..86e2b9e56f 100644
--- a/tensorflow/python/lib/io/py_record_writer.h
+++ b/tensorflow/python/lib/io/py_record_writer.h
@@ -33,7 +33,10 @@ class RecordWriter;
// by multiple threads.
class PyRecordWriter {
public:
- static PyRecordWriter* New(const string& filename);
+ // TODO(vrv): make this take a shared proto to configure
+ // the compression options.
+ static PyRecordWriter* New(const string& filename,
+ const string& compression_type_string);
~PyRecordWriter();
bool WriteRecord(tensorflow::StringPiece record);
diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py
index f792d4d1d8..4694dee5f6 100644
--- a/tensorflow/python/lib/io/tf_record.py
+++ b/tensorflow/python/lib/io/tf_record.py
@@ -23,11 +23,25 @@ from tensorflow.python import pywrap_tensorflow
from tensorflow.python.util import compat
-def tf_record_iterator(path):
+class TFRecordCompressionType(object):
+ NONE = 0
+ ZLIB = 1
+
+
+# NOTE(vrv): This will eventually be converted into a proto. to match
+# the interface used by the C++ RecordWriter.
+class TFRecordOptions(object):
+
+ def __init__(self, compression_type):
+ self.compression_type = compression_type
+
+
+def tf_record_iterator(path, options=None):
"""An iterator that read the records from a TFRecords file.
Args:
path: The path to the TFRecords file.
+ options: (optional) A TFRecordOptions object.
Yields:
Strings.
@@ -35,7 +49,14 @@ def tf_record_iterator(path):
Raises:
IOError: If `path` cannot be opened for reading.
"""
- reader = pywrap_tensorflow.PyRecordReader_New(compat.as_bytes(path), 0)
+ compression_type_string = ""
+ if options:
+ if options.compression_type == TFRecordCompressionType.ZLIB:
+ compression_type_string = "ZLIB"
+
+ reader = pywrap_tensorflow.PyRecordReader_New(
+ compat.as_bytes(path), 0, compat.as_bytes(compression_type_string))
+
if reader is None:
raise IOError("Could not open %s." % path)
while reader.GetNext():
@@ -54,16 +75,23 @@ class TFRecordWriter(object):
@@close
"""
# TODO(josh11b): Support appending?
- def __init__(self, path):
+ def __init__(self, path, options=None):
"""Opens file `path` and creates a `TFRecordWriter` writing to it.
Args:
path: The path to the TFRecords file.
+ options: (optional) A TFRecordOptions object.
Raises:
IOError: If `path` cannot be opened for writing.
"""
- self._writer = pywrap_tensorflow.PyRecordWriter_New(compat.as_bytes(path))
+ compression_type_string = ""
+ if options:
+ if options.compression_type == TFRecordCompressionType.ZLIB:
+ compression_type_string = "ZLIB"
+
+ self._writer = pywrap_tensorflow.PyRecordWriter_New(
+ compat.as_bytes(path), compat.as_bytes(compression_type_string))
if self._writer is None:
raise IOError("Could not write to %s." % path)
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
index a5ed7d7a49..0483e0e7aa 100644
--- a/tensorflow/python/ops/io_ops.py
+++ b/tensorflow/python/ops/io_ops.py
@@ -137,6 +137,7 @@ from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import gen_io_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
@@ -523,13 +524,20 @@ class TFRecordReader(ReaderBase):
"""
# TODO(josh11b): Support serializing and restoring state.
- def __init__(self, name=None):
+ def __init__(self, name=None, options=None):
"""Create a TFRecordReader.
Args:
name: A name for the operation (optional).
+ options: A TFRecordOptions object (optional).
"""
- rr = gen_io_ops._tf_record_reader(name=name)
+ compression_type_string = ""
+ if (options and
+ options.compression_type == python_io.TFRecordCompressionType.ZLIB):
+ compression_type_string = "ZLIB"
+
+ rr = gen_io_ops._tf_record_reader(name=name,
+ compression_type=compression_type_string)
super(TFRecordReader, self).__init__(rr)
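
A minimal sketch of the graph-based reader path (not part of this change's diff), mirroring TFRecordWriterZlibTest.testOneEpoch above; the reader name and file list are hypothetical stand-ins for zlib-compressed TFRecords files written as in the sketch near the top:

import tensorflow as tf

files = ["/tmp/zlib_example.tfrecord"]  # hypothetical zlib-compressed TFRecords file(s)

options = tf.python_io.TFRecordOptions(
    compression_type=tf.python_io.TFRecordCompressionType.ZLIB)
reader = tf.TFRecordReader(name="zlib_reader", options=options)
queue = tf.FIFOQueue(99, [tf.string], shapes=())
key, value = reader.read(queue)

with tf.Session() as sess:
  queue.enqueue_many([files]).run()
  queue.close().run()
  # The reader dequeues filenames as needed; each run() returns one record.
  k, v = sess.run([key, value])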
diff --git a/tensorflow/python/summary/impl/event_file_loader.py b/tensorflow/python/summary/impl/event_file_loader.py
index c30ed80481..509a371412 100644
--- a/tensorflow/python/summary/impl/event_file_loader.py
+++ b/tensorflow/python/summary/impl/event_file_loader.py
@@ -35,7 +35,7 @@ class EventFileLoader(object):
file_path = resource_loader.readahead_file_path(file_path)
logging.debug('Opening a record reader pointing at %s', file_path)
self._reader = pywrap_tensorflow.PyRecordReader_New(
- compat.as_bytes(file_path), 0)
+ compat.as_bytes(file_path), 0, compat.as_bytes(''))
# Store it for logging purposes.
self._file_path = file_path
if not self._reader:
diff --git a/tensorflow/python/summary/impl/gcs_file_loader.py b/tensorflow/python/summary/impl/gcs_file_loader.py
index d820b02d42..b82bb85fc2 100644
--- a/tensorflow/python/summary/impl/gcs_file_loader.py
+++ b/tensorflow/python/summary/impl/gcs_file_loader.py
@@ -46,7 +46,8 @@ class GCSFileLoader(object):
name = temp_file.name
logging.debug('Temp file created at %s', name)
gcs.CopyContents(self._gcs_path, self._gcs_offset, temp_file)
- reader = pywrap_tensorflow.PyRecordReader_New(compat.as_bytes(name), 0)
+ reader = pywrap_tensorflow.PyRecordReader_New(
+ compat.as_bytes(name), 0, compat.as_bytes(''))
while reader.GetNext():
event = event_pb2.Event()
event.ParseFromString(reader.record())