Diffstat (limited to 'tensorflow/core/lib/io/record_reader.cc')
-rw-r--r--  tensorflow/core/lib/io/record_reader.cc  83
 1 file changed, 62 insertions, 21 deletions
diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc
index 76011430de..eb194a14d4 100644
--- a/tensorflow/core/lib/io/record_reader.cc
+++ b/tensorflow/core/lib/io/record_reader.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/record_reader.h"
#include <limits.h>
+
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/crc32c.h"
@@ -24,38 +25,76 @@ limitations under the License.
namespace tensorflow {
namespace io {
-RecordReader::RecordReader(RandomAccessFile* file) : src_(file) {}
+RecordReader::RecordReader(RandomAccessFile* file,
+                           const RecordReaderOptions& options)
+    : src_(file), options_(options) {
+  if (options.compression_type == RecordReaderOptions::ZLIB_COMPRESSION) {
+    zlib_input_buffer_.reset(new ZlibInputBuffer(
+        src_, options.zlib_options.input_buffer_size,
+        options.zlib_options.output_buffer_size, options.zlib_options));
+  } else if (options.compression_type == RecordReaderOptions::NONE) {
+    // Nothing to do.
+  } else {
+    LOG(FATAL) << "Unspecified compression type :" << options.compression_type;
+  }
+}
RecordReader::~RecordReader() {}
// Read n+4 bytes from file, verify that checksum of first n bytes is
// stored in the last 4 bytes and store the first n bytes in *result.
// May use *storage as backing store.
-static Status ReadChecksummed(RandomAccessFile* file, uint64 offset, size_t n,
-                              StringPiece* result, string* storage) {
+Status RecordReader::ReadChecksummed(uint64 offset, size_t n,
+                                     StringPiece* result, string* storage) {
  if (n >= SIZE_MAX - sizeof(uint32)) {
    return errors::DataLoss("record size too large");
  }
  const size_t expected = n + sizeof(uint32);
  storage->resize(expected);
-  StringPiece data;
-  Status s = file->Read(offset, expected, &data, &(*storage)[0]);
-  if (!s.ok()) {
-    return s;
-  }
-  if (data.size() != expected) {
-    if (data.size() == 0) {
-      return errors::OutOfRange("eof");
-    } else {
-      return errors::DataLoss("truncated record at ", offset);
+
+  if (zlib_input_buffer_) {
+    // If we have a zlib compressed buffer, we assume that the
+    // file is being read sequentially, and we use the underlying
+    // implementation to read the data.
+    //
+    // No checks are done to validate that the file is being read
+    // sequentially. At some point the zlib input buffer may support
+    // seeking, possibly inefficiently.
+    TF_RETURN_IF_ERROR(zlib_input_buffer_->ReadNBytes(expected, storage));
+
+    if (storage->size() != expected) {
+      if (storage->size() == 0) {
+        return errors::OutOfRange("eof");
+      } else {
+        return errors::DataLoss("truncated record at ", offset);
+      }
    }
+
+    uint32 masked_crc = core::DecodeFixed32(storage->data() + n);
+    if (crc32c::Unmask(masked_crc) != crc32c::Value(storage->data(), n)) {
+      return errors::DataLoss("corrupted record at ", offset);
+    }
+    *result = StringPiece(storage->data(), n);
+  } else {
+    // This version supports reading from arbitrary offsets
+    // since we are accessing the random access file directly.
+    StringPiece data;
+    TF_RETURN_IF_ERROR(src_->Read(offset, expected, &data, &(*storage)[0]));
+    if (data.size() != expected) {
+      if (data.size() == 0) {
+        return errors::OutOfRange("eof");
+      } else {
+        return errors::DataLoss("truncated record at ", offset);
+      }
+    }
+    uint32 masked_crc = core::DecodeFixed32(data.data() + n);
+    if (crc32c::Unmask(masked_crc) != crc32c::Value(data.data(), n)) {
+      return errors::DataLoss("corrupted record at ", offset);
+    }
+    *result = StringPiece(data.data(), n);
  }
-  uint32 masked_crc = core::DecodeFixed32(data.data() + n);
-  if (crc32c::Unmask(masked_crc) != crc32c::Value(data.data(), n)) {
-    return errors::DataLoss("corrupted record at ", offset);
-  }
-  *result = StringPiece(data.data(), n);
+
  return Status::OK();
}
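
For orientation, here is a small caller-side sketch of the new constructor with ZLIB compression enabled. Only RecordReaderOptions, its compression_type/zlib_options fields, the RecordReader constructor, and ReadRecord come from this file; the Env-based file opening and the helper name are illustrative assumptions, not part of this change.

#include <memory>
#include <string>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/platform/env.h"

// Hypothetical helper: open `fname` and read its first record through a
// zlib-decompressing RecordReader.
tensorflow::Status ReadFirstRecord(tensorflow::Env* env,
                                   const std::string& fname,
                                   std::string* record) {
  std::unique_ptr<tensorflow::RandomAccessFile> file;
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(fname, &file));

  tensorflow::io::RecordReaderOptions options;
  options.compression_type =
      tensorflow::io::RecordReaderOptions::ZLIB_COMPRESSION;
  // options.zlib_options (input/output buffer sizes) is left at its defaults.

  tensorflow::io::RecordReader reader(file.get(), options);
  tensorflow::uint64 offset = 0;
  return reader.ReadRecord(&offset, record);
}
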
@@ -63,9 +102,9 @@ Status RecordReader::ReadRecord(uint64* offset, string* record) {
  static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
  static const size_t kFooterSize = sizeof(uint32);
-  // Read length
+  // Read header data.
  StringPiece lbuf;
-  Status s = ReadChecksummed(src_, *offset, sizeof(uint64), &lbuf, record);
+  Status s = ReadChecksummed(*offset, sizeof(uint64), &lbuf, record);
  if (!s.ok()) {
    return s;
  }
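
The kHeaderSize/kFooterSize constants above reflect the per-record framing that ReadChecksummed() verifies: a length header with its own masked CRC32C, then the payload, then a masked CRC32C of the payload. As a sketch of the layout implied by these constants:

// Per-record layout assumed by the constants above:
//   uint64  length                    \  kHeaderSize = 12 bytes
//   uint32  masked crc32c of length   /
//   byte    data[length]
//   uint32  masked crc32c of data     -- kFooterSize = 4 bytes
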
@@ -73,19 +112,21 @@ Status RecordReader::ReadRecord(uint64* offset, string* record) {
  // Read data
  StringPiece data;
-  s = ReadChecksummed(src_, *offset + kHeaderSize, length, &data, record);
+  s = ReadChecksummed(*offset + kHeaderSize, length, &data, record);
  if (!s.ok()) {
    if (errors::IsOutOfRange(s)) {
      s = errors::DataLoss("truncated record at ", *offset);
    }
    return s;
  }
+
  if (record->data() != data.data()) {
    // RandomAccessFile placed the data in some other location.
    memmove(&(*record)[0], data.data(), data.size());
  }
  record->resize(data.size());
+
  *offset += kHeaderSize + length + kFooterSize;
  return Status::OK();
}
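
Finally, a hedged usage sketch of the sequential read loop this reader expects (the helper name is hypothetical). For ZLIB_COMPRESSION the records must be consumed in order, matching the assumption documented in ReadChecksummed().

#include <string>
#include <vector>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/record_reader.h"

// Hypothetical helper: read every record, letting ReadRecord() advance the
// byte offset. OutOfRange marks a clean end of file; any other error is
// surfaced to the caller (e.g. DataLoss for truncation or a CRC mismatch).
tensorflow::Status ReadAllRecords(tensorflow::io::RecordReader* reader,
                                  std::vector<std::string>* records) {
  tensorflow::uint64 offset = 0;
  std::string record;
  while (true) {
    tensorflow::Status s = reader->ReadRecord(&offset, &record);
    if (tensorflow::errors::IsOutOfRange(s)) return tensorflow::Status::OK();
    if (!s.ok()) return s;
    records->push_back(record);
  }
}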