diff options
Diffstat (limited to 'tensorflow/core/lib/io/record_reader.cc')
-rw-r--r-- | tensorflow/core/lib/io/record_reader.cc | 147 |
1 files changed, 105 insertions, 42 deletions
diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc index c24628be57..6de850bb20 100644 --- a/tensorflow/core/lib/io/record_reader.cc +++ b/tensorflow/core/lib/io/record_reader.cc @@ -56,55 +56,110 @@ RecordReaderOptions RecordReaderOptions::CreateRecordReaderOptions( RecordReader::RecordReader(RandomAccessFile* file, const RecordReaderOptions& options) - : options_(options), - input_stream_(new RandomAccessInputStream(file)), - last_read_failed_(false) { + : src_(file), options_(options) { if (options.buffer_size > 0) { - input_stream_.reset(new BufferedInputStream(input_stream_.release(), - options.buffer_size, true)); + input_stream_.reset(new BufferedInputStream(file, options.buffer_size)); + } else { + input_stream_.reset(new RandomAccessInputStream(file)); } if (options.compression_type == RecordReaderOptions::ZLIB_COMPRESSION) { // We don't have zlib available on all embedded platforms, so fail. #if defined(IS_SLIM_BUILD) LOG(FATAL) << "Zlib compression is unsupported on mobile platforms."; #else // IS_SLIM_BUILD - input_stream_.reset(new ZlibInputStream( - input_stream_.release(), options.zlib_options.input_buffer_size, - options.zlib_options.output_buffer_size, options.zlib_options, true)); + zlib_input_stream_.reset(new ZlibInputStream( + input_stream_.get(), options.zlib_options.input_buffer_size, + options.zlib_options.output_buffer_size, options.zlib_options)); #endif // IS_SLIM_BUILD } else if (options.compression_type == RecordReaderOptions::NONE) { // Nothing to do. } else { - LOG(FATAL) << "Unrecognized compression type :" << options.compression_type; + LOG(FATAL) << "Unspecified compression type :" << options.compression_type; } } // Read n+4 bytes from file, verify that checksum of first n bytes is // stored in the last 4 bytes and store the first n bytes in *result. -// -// offset corresponds to the user-provided value to ReadRecord() -// and is used only in error messages. -Status RecordReader::ReadChecksummed(uint64 offset, size_t n, string* result) { +// May use *storage as backing store. +Status RecordReader::ReadChecksummed(uint64 offset, size_t n, + StringPiece* result, string* storage) { if (n >= SIZE_MAX - sizeof(uint32)) { return errors::DataLoss("record size too large"); } const size_t expected = n + sizeof(uint32); - TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(expected, result)); + storage->resize(expected); + +#if !defined(IS_SLIM_BUILD) + if (zlib_input_stream_) { + // If we have a zlib compressed buffer, we assume that the + // file is being read sequentially, and we use the underlying + // implementation to read the data. + // + // No checks are done to validate that the file is being read + // sequentially. At some point the zlib input buffer may support + // seeking, possibly inefficiently. + TF_RETURN_IF_ERROR(zlib_input_stream_->ReadNBytes(expected, storage)); + + if (storage->size() != expected) { + if (storage->empty()) { + return errors::OutOfRange("eof"); + } else { + return errors::DataLoss("truncated record at ", offset); + } + } - if (result->size() != expected) { - if (result->empty()) { - return errors::OutOfRange("eof"); + uint32 masked_crc = core::DecodeFixed32(storage->data() + n); + if (crc32c::Unmask(masked_crc) != crc32c::Value(storage->data(), n)) { + return errors::DataLoss("corrupted record at ", offset); + } + *result = StringPiece(storage->data(), n); + } else { +#endif // IS_SLIM_BUILD + if (options_.buffer_size > 0) { + // If we have a buffer, we assume that the file is being read + // sequentially, and we use the underlying implementation to read the + // data. + // + // No checks are done to validate that the file is being read + // sequentially. + TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(expected, storage)); + + if (storage->size() != expected) { + if (storage->empty()) { + return errors::OutOfRange("eof"); + } else { + return errors::DataLoss("truncated record at ", offset); + } + } + + const uint32 masked_crc = core::DecodeFixed32(storage->data() + n); + if (crc32c::Unmask(masked_crc) != crc32c::Value(storage->data(), n)) { + return errors::DataLoss("corrupted record at ", offset); + } + *result = StringPiece(storage->data(), n); } else { - return errors::DataLoss("truncated record at ", offset); + // This version supports reading from arbitrary offsets + // since we are accessing the random access file directly. + StringPiece data; + TF_RETURN_IF_ERROR(src_->Read(offset, expected, &data, &(*storage)[0])); + if (data.size() != expected) { + if (data.empty()) { + return errors::OutOfRange("eof"); + } else { + return errors::DataLoss("truncated record at ", offset); + } + } + const uint32 masked_crc = core::DecodeFixed32(data.data() + n); + if (crc32c::Unmask(masked_crc) != crc32c::Value(data.data(), n)) { + return errors::DataLoss("corrupted record at ", offset); + } + *result = StringPiece(data.data(), n); } +#if !defined(IS_SLIM_BUILD) } +#endif // IS_SLIM_BUILD - const uint32 masked_crc = core::DecodeFixed32(result->data() + n); - if (crc32c::Unmask(masked_crc) != crc32c::Value(result->data(), n)) { - return errors::DataLoss("corrupted record at ", offset); - } - result->resize(n); return Status::OK(); } @@ -112,42 +167,50 @@ Status RecordReader::ReadRecord(uint64* offset, string* record) { static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32); static const size_t kFooterSize = sizeof(uint32); - // Position the input stream. - int64 curr_pos = input_stream_->Tell(); - int64 desired_pos = static_cast<int64>(*offset); - if (curr_pos > desired_pos || curr_pos < 0 /* EOF */ || - (curr_pos == desired_pos && last_read_failed_)) { - last_read_failed_ = false; - TF_RETURN_IF_ERROR(input_stream_->Reset()); - TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(desired_pos)); - } else if (curr_pos < desired_pos) { - TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(desired_pos - curr_pos)); - } - DCHECK_EQ(desired_pos, input_stream_->Tell()); - // Read header data. - Status s = ReadChecksummed(*offset, sizeof(uint64), record); + StringPiece lbuf; + Status s = ReadChecksummed(*offset, sizeof(uint64), &lbuf, record); if (!s.ok()) { - last_read_failed_ = true; return s; } - const uint64 length = core::DecodeFixed64(record->data()); + const uint64 length = core::DecodeFixed64(lbuf.data()); // Read data - s = ReadChecksummed(*offset + kHeaderSize, length, record); + StringPiece data; + s = ReadChecksummed(*offset + kHeaderSize, length, &data, record); if (!s.ok()) { - last_read_failed_ = true; if (errors::IsOutOfRange(s)) { s = errors::DataLoss("truncated record at ", *offset); } return s; } + if (record->data() != data.data()) { + // RandomAccessFile placed the data in some other location. + memmove(&(*record)[0], data.data(), data.size()); + } + + record->resize(data.size()); + *offset += kHeaderSize + length + kFooterSize; - DCHECK_EQ(*offset, input_stream_->Tell()); return Status::OK(); } +Status RecordReader::SkipNBytes(uint64 offset) { +#if !defined(IS_SLIM_BUILD) + if (zlib_input_stream_) { + TF_RETURN_IF_ERROR(zlib_input_stream_->SkipNBytes(offset)); + } else { +#endif + if (options_.buffer_size > 0) { + TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(offset)); + } +#if !defined(IS_SLIM_BUILD) + } +#endif + return Status::OK(); +} // namespace io + SequentialRecordReader::SequentialRecordReader( RandomAccessFile* file, const RecordReaderOptions& options) : underlying_(file, options), offset_(0) {} |