diff options
Diffstat (limited to 'tensorflow/core/lib/io/record_reader.cc')
-rw-r--r-- | tensorflow/core/lib/io/record_reader.cc | 83 |
1 files changed, 62 insertions, 21 deletions
diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc index 76011430de..eb194a14d4 100644 --- a/tensorflow/core/lib/io/record_reader.cc +++ b/tensorflow/core/lib/io/record_reader.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/lib/io/record_reader.h" #include <limits.h> + #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/crc32c.h" @@ -24,38 +25,76 @@ limitations under the License. namespace tensorflow { namespace io { -RecordReader::RecordReader(RandomAccessFile* file) : src_(file) {} +RecordReader::RecordReader(RandomAccessFile* file, + const RecordReaderOptions& options) + : src_(file), options_(options) { + if (options.compression_type == RecordReaderOptions::ZLIB_COMPRESSION) { + zlib_input_buffer_.reset(new ZlibInputBuffer( + src_, options.zlib_options.input_buffer_size, + options.zlib_options.output_buffer_size, options.zlib_options)); + } else if (options.compression_type == RecordReaderOptions::NONE) { + // Nothing to do. + } else { + LOG(FATAL) << "Unspecified compression type :" << options.compression_type; + } +} RecordReader::~RecordReader() {} // Read n+4 bytes from file, verify that checksum of first n bytes is // stored in the last 4 bytes and store the first n bytes in *result. // May use *storage as backing store. -static Status ReadChecksummed(RandomAccessFile* file, uint64 offset, size_t n, - StringPiece* result, string* storage) { +Status RecordReader::ReadChecksummed(uint64 offset, size_t n, + StringPiece* result, string* storage) { if (n >= SIZE_MAX - sizeof(uint32)) { return errors::DataLoss("record size too large"); } const size_t expected = n + sizeof(uint32); storage->resize(expected); - StringPiece data; - Status s = file->Read(offset, expected, &data, &(*storage)[0]); - if (!s.ok()) { - return s; - } - if (data.size() != expected) { - if (data.size() == 0) { - return errors::OutOfRange("eof"); - } else { - return errors::DataLoss("truncated record at ", offset); + + if (zlib_input_buffer_) { + // If we have a zlib compressed buffer, we assume that the + // file is being read sequentially, and we use the underlying + // implementation to read the data. + // + // No checks are done to validate that the file is being read + // sequentially. At some point the zlib input buffer may support + // seeking, possibly inefficiently. + TF_RETURN_IF_ERROR(zlib_input_buffer_->ReadNBytes(expected, storage)); + + if (storage->size() != expected) { + if (storage->size() == 0) { + return errors::OutOfRange("eof"); + } else { + return errors::DataLoss("truncated record at ", offset); + } } + + uint32 masked_crc = core::DecodeFixed32(storage->data() + n); + if (crc32c::Unmask(masked_crc) != crc32c::Value(storage->data(), n)) { + return errors::DataLoss("corrupted record at ", offset); + } + *result = StringPiece(storage->data(), n); + } else { + // This version supports reading from arbitrary offsets + // since we are accessing the random access file directly. + StringPiece data; + TF_RETURN_IF_ERROR(src_->Read(offset, expected, &data, &(*storage)[0])); + if (data.size() != expected) { + if (data.size() == 0) { + return errors::OutOfRange("eof"); + } else { + return errors::DataLoss("truncated record at ", offset); + } + } + uint32 masked_crc = core::DecodeFixed32(data.data() + n); + if (crc32c::Unmask(masked_crc) != crc32c::Value(data.data(), n)) { + return errors::DataLoss("corrupted record at ", offset); + } + *result = StringPiece(data.data(), n); } - uint32 masked_crc = core::DecodeFixed32(data.data() + n); - if (crc32c::Unmask(masked_crc) != crc32c::Value(data.data(), n)) { - return errors::DataLoss("corrupted record at ", offset); - } - *result = StringPiece(data.data(), n); + return Status::OK(); } @@ -63,9 +102,9 @@ Status RecordReader::ReadRecord(uint64* offset, string* record) { static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32); static const size_t kFooterSize = sizeof(uint32); - // Read length + // Read header data. StringPiece lbuf; - Status s = ReadChecksummed(src_, *offset, sizeof(uint64), &lbuf, record); + Status s = ReadChecksummed(*offset, sizeof(uint64), &lbuf, record); if (!s.ok()) { return s; } @@ -73,19 +112,21 @@ Status RecordReader::ReadRecord(uint64* offset, string* record) { // Read data StringPiece data; - s = ReadChecksummed(src_, *offset + kHeaderSize, length, &data, record); + s = ReadChecksummed(*offset + kHeaderSize, length, &data, record); if (!s.ok()) { if (errors::IsOutOfRange(s)) { s = errors::DataLoss("truncated record at ", *offset); } return s; } + if (record->data() != data.data()) { // RandomAccessFile placed the data in some other location. memmove(&(*record)[0], data.data(), data.size()); } record->resize(data.size()); + *offset += kHeaderSize + length + kFooterSize; return Status::OK(); } |