diff options
author | Asim Shankar <ashankar@google.com> | 2018-04-19 23:08:53 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-19 23:11:36 -0700 |
commit | f7e8fbb28a0fa4e979a94d7b458706abf48f7deb (patch) | |
tree | 67760aaace78304e440f3e21e77fda395a65ea9e /tensorflow/core/lib/io | |
parent | 70b8d21edcc84818835c9e2940a5df288c309d45 (diff) |
Automated g4 rollback of changelist 193602050
PiperOrigin-RevId: 193625346
Diffstat (limited to 'tensorflow/core/lib/io')
-rw-r--r-- | tensorflow/core/lib/io/record_reader.cc | 147 | ||||
-rw-r--r-- | tensorflow/core/lib/io/record_reader.h | 16 | ||||
-rw-r--r-- | tensorflow/core/lib/io/recordio_test.cc | 212 | ||||
-rw-r--r-- | tensorflow/core/lib/io/zlib_inputstream.cc | 16 | ||||
-rw-r--r-- | tensorflow/core/lib/io/zlib_inputstream.h | 19 |
5 files changed, 220 insertions, 190 deletions
diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc index 6de850bb20..c24628be57 100644 --- a/tensorflow/core/lib/io/record_reader.cc +++ b/tensorflow/core/lib/io/record_reader.cc @@ -56,110 +56,55 @@ RecordReaderOptions RecordReaderOptions::CreateRecordReaderOptions( RecordReader::RecordReader(RandomAccessFile* file, const RecordReaderOptions& options) - : src_(file), options_(options) { + : options_(options), + input_stream_(new RandomAccessInputStream(file)), + last_read_failed_(false) { if (options.buffer_size > 0) { - input_stream_.reset(new BufferedInputStream(file, options.buffer_size)); - } else { - input_stream_.reset(new RandomAccessInputStream(file)); + input_stream_.reset(new BufferedInputStream(input_stream_.release(), + options.buffer_size, true)); } if (options.compression_type == RecordReaderOptions::ZLIB_COMPRESSION) { // We don't have zlib available on all embedded platforms, so fail. #if defined(IS_SLIM_BUILD) LOG(FATAL) << "Zlib compression is unsupported on mobile platforms."; #else // IS_SLIM_BUILD - zlib_input_stream_.reset(new ZlibInputStream( - input_stream_.get(), options.zlib_options.input_buffer_size, - options.zlib_options.output_buffer_size, options.zlib_options)); + input_stream_.reset(new ZlibInputStream( + input_stream_.release(), options.zlib_options.input_buffer_size, + options.zlib_options.output_buffer_size, options.zlib_options, true)); #endif // IS_SLIM_BUILD } else if (options.compression_type == RecordReaderOptions::NONE) { // Nothing to do. } else { - LOG(FATAL) << "Unspecified compression type :" << options.compression_type; + LOG(FATAL) << "Unrecognized compression type :" << options.compression_type; } } // Read n+4 bytes from file, verify that checksum of first n bytes is // stored in the last 4 bytes and store the first n bytes in *result. -// May use *storage as backing store. -Status RecordReader::ReadChecksummed(uint64 offset, size_t n, - StringPiece* result, string* storage) { +// +// offset corresponds to the user-provided value to ReadRecord() +// and is used only in error messages. +Status RecordReader::ReadChecksummed(uint64 offset, size_t n, string* result) { if (n >= SIZE_MAX - sizeof(uint32)) { return errors::DataLoss("record size too large"); } const size_t expected = n + sizeof(uint32); - storage->resize(expected); - -#if !defined(IS_SLIM_BUILD) - if (zlib_input_stream_) { - // If we have a zlib compressed buffer, we assume that the - // file is being read sequentially, and we use the underlying - // implementation to read the data. - // - // No checks are done to validate that the file is being read - // sequentially. At some point the zlib input buffer may support - // seeking, possibly inefficiently. - TF_RETURN_IF_ERROR(zlib_input_stream_->ReadNBytes(expected, storage)); - - if (storage->size() != expected) { - if (storage->empty()) { - return errors::OutOfRange("eof"); - } else { - return errors::DataLoss("truncated record at ", offset); - } - } + TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(expected, result)); - uint32 masked_crc = core::DecodeFixed32(storage->data() + n); - if (crc32c::Unmask(masked_crc) != crc32c::Value(storage->data(), n)) { - return errors::DataLoss("corrupted record at ", offset); - } - *result = StringPiece(storage->data(), n); - } else { -#endif // IS_SLIM_BUILD - if (options_.buffer_size > 0) { - // If we have a buffer, we assume that the file is being read - // sequentially, and we use the underlying implementation to read the - // data. - // - // No checks are done to validate that the file is being read - // sequentially. - TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(expected, storage)); - - if (storage->size() != expected) { - if (storage->empty()) { - return errors::OutOfRange("eof"); - } else { - return errors::DataLoss("truncated record at ", offset); - } - } - - const uint32 masked_crc = core::DecodeFixed32(storage->data() + n); - if (crc32c::Unmask(masked_crc) != crc32c::Value(storage->data(), n)) { - return errors::DataLoss("corrupted record at ", offset); - } - *result = StringPiece(storage->data(), n); + if (result->size() != expected) { + if (result->empty()) { + return errors::OutOfRange("eof"); } else { - // This version supports reading from arbitrary offsets - // since we are accessing the random access file directly. - StringPiece data; - TF_RETURN_IF_ERROR(src_->Read(offset, expected, &data, &(*storage)[0])); - if (data.size() != expected) { - if (data.empty()) { - return errors::OutOfRange("eof"); - } else { - return errors::DataLoss("truncated record at ", offset); - } - } - const uint32 masked_crc = core::DecodeFixed32(data.data() + n); - if (crc32c::Unmask(masked_crc) != crc32c::Value(data.data(), n)) { - return errors::DataLoss("corrupted record at ", offset); - } - *result = StringPiece(data.data(), n); + return errors::DataLoss("truncated record at ", offset); } -#if !defined(IS_SLIM_BUILD) } -#endif // IS_SLIM_BUILD + const uint32 masked_crc = core::DecodeFixed32(result->data() + n); + if (crc32c::Unmask(masked_crc) != crc32c::Value(result->data(), n)) { + return errors::DataLoss("corrupted record at ", offset); + } + result->resize(n); return Status::OK(); } @@ -167,50 +112,42 @@ Status RecordReader::ReadRecord(uint64* offset, string* record) { static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32); static const size_t kFooterSize = sizeof(uint32); + // Position the input stream. + int64 curr_pos = input_stream_->Tell(); + int64 desired_pos = static_cast<int64>(*offset); + if (curr_pos > desired_pos || curr_pos < 0 /* EOF */ || + (curr_pos == desired_pos && last_read_failed_)) { + last_read_failed_ = false; + TF_RETURN_IF_ERROR(input_stream_->Reset()); + TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(desired_pos)); + } else if (curr_pos < desired_pos) { + TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(desired_pos - curr_pos)); + } + DCHECK_EQ(desired_pos, input_stream_->Tell()); + // Read header data. - StringPiece lbuf; - Status s = ReadChecksummed(*offset, sizeof(uint64), &lbuf, record); + Status s = ReadChecksummed(*offset, sizeof(uint64), record); if (!s.ok()) { + last_read_failed_ = true; return s; } - const uint64 length = core::DecodeFixed64(lbuf.data()); + const uint64 length = core::DecodeFixed64(record->data()); // Read data - StringPiece data; - s = ReadChecksummed(*offset + kHeaderSize, length, &data, record); + s = ReadChecksummed(*offset + kHeaderSize, length, record); if (!s.ok()) { + last_read_failed_ = true; if (errors::IsOutOfRange(s)) { s = errors::DataLoss("truncated record at ", *offset); } return s; } - if (record->data() != data.data()) { - // RandomAccessFile placed the data in some other location. - memmove(&(*record)[0], data.data(), data.size()); - } - - record->resize(data.size()); - *offset += kHeaderSize + length + kFooterSize; + DCHECK_EQ(*offset, input_stream_->Tell()); return Status::OK(); } -Status RecordReader::SkipNBytes(uint64 offset) { -#if !defined(IS_SLIM_BUILD) - if (zlib_input_stream_) { - TF_RETURN_IF_ERROR(zlib_input_stream_->SkipNBytes(offset)); - } else { -#endif - if (options_.buffer_size > 0) { - TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(offset)); - } -#if !defined(IS_SLIM_BUILD) - } -#endif - return Status::OK(); -} // namespace io - SequentialRecordReader::SequentialRecordReader( RandomAccessFile* file, const RecordReaderOptions& options) : underlying_(file, options), offset_(0) {} diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h index 26278e0328..f6d587dfa0 100644 --- a/tensorflow/core/lib/io/record_reader.h +++ b/tensorflow/core/lib/io/record_reader.h @@ -69,25 +69,14 @@ class RecordReader { // Read the record at "*offset" into *record and update *offset to // point to the offset of the next record. Returns OK on success, // OUT_OF_RANGE for end of file, or something else for an error. - // - // Note: if buffering is used (with or without compression), access must be - // sequential. Status ReadRecord(uint64* offset, string* record); - // Skip the records till "offset". Returns OK on success, - // OUT_OF_RANGE for end of file, or something else for an error. - Status SkipNBytes(uint64 offset); - private: - Status ReadChecksummed(uint64 offset, size_t n, StringPiece* result, - string* storage); + Status ReadChecksummed(uint64 offset, size_t n, string* result); - RandomAccessFile* src_; RecordReaderOptions options_; std::unique_ptr<InputStreamInterface> input_stream_; -#if !defined(IS_SLIM_BUILD) - std::unique_ptr<ZlibInputStream> zlib_input_stream_; -#endif // IS_SLIM_BUILD + bool last_read_failed_; TF_DISALLOW_COPY_AND_ASSIGN(RecordReader); }; @@ -121,7 +110,6 @@ class SequentialRecordReader { return errors::InvalidArgument( "Trying to seek offset: ", offset, " which is less than the current offset: ", offset_); - TF_RETURN_IF_ERROR(underlying_.SkipNBytes(offset - offset_)); offset_ = offset; return Status::OK(); } diff --git a/tensorflow/core/lib/io/recordio_test.cc b/tensorflow/core/lib/io/recordio_test.cc index 63235761d9..da514bd21c 100644 --- a/tensorflow/core/lib/io/recordio_test.cc +++ b/tensorflow/core/lib/io/recordio_test.cc @@ -26,10 +26,11 @@ limitations under the License. namespace tensorflow { namespace io { +namespace { // Construct a string of the specified length made out of the supplied // partial string. -static string BigString(const string& partial_string, size_t n) { +string BigString(const string& partial_string, size_t n) { string result; while (result.size() < n) { result.append(partial_string); @@ -39,62 +40,66 @@ static string BigString(const string& partial_string, size_t n) { } // Construct a string from a number -static string NumberString(int n) { +string NumberString(int n) { char buf[50]; snprintf(buf, sizeof(buf), "%d.", n); return string(buf); } // Return a skewed potentially long string -static string RandomSkewedString(int i, random::SimplePhilox* rnd) { +string RandomSkewedString(int i, random::SimplePhilox* rnd) { return BigString(NumberString(i), rnd->Skewed(17)); } -class RecordioTest : public ::testing::Test { +class StringDest : public WritableFile { + public: + explicit StringDest(string* contents) : contents_(contents) {} + + Status Close() override { return Status::OK(); } + Status Flush() override { return Status::OK(); } + Status Sync() override { return Status::OK(); } + Status Append(const StringPiece& slice) override { + contents_->append(slice.data(), slice.size()); + return Status::OK(); + } + private: - class StringDest : public WritableFile { - public: - string contents_; - - Status Close() override { return Status::OK(); } - Status Flush() override { return Status::OK(); } - Status Sync() override { return Status::OK(); } - Status Append(const StringPiece& slice) override { - contents_.append(slice.data(), slice.size()); - return Status::OK(); + string* contents_; +}; + +class StringSource : public RandomAccessFile { + public: + explicit StringSource(string* contents) + : contents_(contents), force_error_(false) {} + + Status Read(uint64 offset, size_t n, StringPiece* result, + char* scratch) const override { + if (force_error_) { + force_error_ = false; + return errors::DataLoss("read error"); } - }; - - class StringSource : public RandomAccessFile { - public: - StringPiece contents_; - mutable bool force_error_; - mutable bool returned_partial_; - StringSource() : force_error_(false), returned_partial_(false) {} - - Status Read(uint64 offset, size_t n, StringPiece* result, - char* scratch) const override { - EXPECT_FALSE(returned_partial_) << "must not Read() after eof/error"; - - if (force_error_) { - force_error_ = false; - returned_partial_ = true; - return errors::DataLoss("read error"); - } - - if (offset >= contents_.size()) { - return errors::OutOfRange("end of file"); - } - - if (contents_.size() < offset + n) { - n = contents_.size() - offset; - returned_partial_ = true; - } - *result = StringPiece(contents_.data() + offset, n); - return Status::OK(); + + if (offset >= contents_->size()) { + return errors::OutOfRange("end of file"); + } + + if (contents_->size() < offset + n) { + n = contents_->size() - offset; } - }; + *result = StringPiece(contents_->data() + offset, n); + return Status::OK(); + } + + void force_error() { force_error_ = true; } + + private: + string* contents_; + mutable bool force_error_; +}; +class RecordioTest : public ::testing::Test { + private: + string contents_; StringDest dest_; StringSource source_; bool reading_; @@ -104,7 +109,9 @@ class RecordioTest : public ::testing::Test { public: RecordioTest() - : reading_(false), + : dest_(&contents_), + source_(&contents_), + reading_(false), readpos_(0), writer_(new RecordWriter(&dest_)), reader_(new RecordReader(&source_)) {} @@ -119,12 +126,11 @@ class RecordioTest : public ::testing::Test { TF_ASSERT_OK(writer_->WriteRecord(StringPiece(msg))); } - size_t WrittenBytes() const { return dest_.contents_.size(); } + size_t WrittenBytes() const { return contents_.size(); } string Read() { if (!reading_) { reading_ = true; - source_.contents_ = StringPiece(dest_.contents_); } string record; Status s = reader_->ReadRecord(&readpos_, &record); @@ -137,26 +143,20 @@ class RecordioTest : public ::testing::Test { } } - void IncrementByte(int offset, int delta) { - dest_.contents_[offset] += delta; - } + void IncrementByte(int offset, int delta) { contents_[offset] += delta; } - void SetByte(int offset, char new_byte) { - dest_.contents_[offset] = new_byte; - } + void SetByte(int offset, char new_byte) { contents_[offset] = new_byte; } - void ShrinkSize(int bytes) { - dest_.contents_.resize(dest_.contents_.size() - bytes); - } + void ShrinkSize(int bytes) { contents_.resize(contents_.size() - bytes); } void FixChecksum(int header_offset, int len) { // Compute crc of type/len/data - uint32_t crc = crc32c::Value(&dest_.contents_[header_offset + 6], 1 + len); + uint32_t crc = crc32c::Value(&contents_[header_offset + 6], 1 + len); crc = crc32c::Mask(crc); - core::EncodeFixed32(&dest_.contents_[header_offset], crc); + core::EncodeFixed32(&contents_[header_offset], crc); } - void ForceError() { source_.force_error_ = true; } + void ForceError() { source_.force_error(); } void StartReadingAt(uint64_t initial_offset) { readpos_ = initial_offset; } @@ -165,7 +165,6 @@ class RecordioTest : public ::testing::Test { Write("bar"); Write(BigString("x", 10000)); reading_ = true; - source_.contents_ = StringPiece(dest_.contents_); uint64 offset = WrittenBytes() + offset_past_end; string record; Status s = reader_->ReadRecord(&offset, &record); @@ -217,16 +216,100 @@ TEST_F(RecordioTest, RandomRead) { ASSERT_EQ("EOF", Read()); } +void TestNonSequentialReads(const RecordWriterOptions& writer_options, + const RecordReaderOptions& reader_options) { + string contents; + StringDest dst(&contents); + RecordWriter writer(&dst, writer_options); + for (int i = 0; i < 10; ++i) { + TF_ASSERT_OK(writer.WriteRecord(NumberString(i))) << i; + } + TF_ASSERT_OK(writer.Close()); + + StringSource file(&contents); + RecordReader reader(&file, reader_options); + + string record; + // First read sequentially to fill in the offsets table. + uint64 offsets[10] = {0}; + uint64 offset = 0; + for (int i = 0; i < 10; ++i) { + offsets[i] = offset; + TF_ASSERT_OK(reader.ReadRecord(&offset, &record)) << i; + } + + // Read randomly: First go back to record #3 then forward to #8. + offset = offsets[3]; + TF_ASSERT_OK(reader.ReadRecord(&offset, &record)); + EXPECT_EQ("3.", record); + EXPECT_EQ(offsets[4], offset); + + offset = offsets[8]; + TF_ASSERT_OK(reader.ReadRecord(&offset, &record)); + EXPECT_EQ("8.", record); + EXPECT_EQ(offsets[9], offset); +} + +TEST_F(RecordioTest, NonSequentialReads) { + TestNonSequentialReads(RecordWriterOptions(), RecordReaderOptions()); +} + +TEST_F(RecordioTest, NonSequentialReadsWithReadBuffer) { + RecordReaderOptions options; + options.buffer_size = 1 << 10; + TestNonSequentialReads(RecordWriterOptions(), options); +} + +TEST_F(RecordioTest, NonSequentialReadsWithCompression) { + TestNonSequentialReads( + RecordWriterOptions::CreateRecordWriterOptions("ZLIB"), + RecordReaderOptions::CreateRecordReaderOptions("ZLIB")); +} + // Tests of all the error paths in log_reader.cc follow: -static void AssertHasSubstr(StringPiece s, StringPiece expected) { +void AssertHasSubstr(StringPiece s, StringPiece expected) { EXPECT_TRUE(str_util::StrContains(s, expected)) << s << " does not contain " << expected; } +void TestReadError(const RecordWriterOptions& writer_options, + const RecordReaderOptions& reader_options) { + const string wrote = BigString("well hello there!", 100); + string contents; + StringDest dst(&contents); + TF_ASSERT_OK(RecordWriter(&dst, writer_options).WriteRecord(wrote)); + + StringSource file(&contents); + RecordReader reader(&file, reader_options); + + uint64 offset = 0; + string read; + file.force_error(); + Status status = reader.ReadRecord(&offset, &read); + ASSERT_TRUE(errors::IsDataLoss(status)); + ASSERT_EQ(0, offset); + + // A failed Read() shouldn't update the offset, and thus a retry shouldn't + // lose the record. + status = reader.ReadRecord(&offset, &read); + ASSERT_TRUE(status.ok()) << status; + EXPECT_GT(offset, 0); + EXPECT_EQ(wrote, read); +} + TEST_F(RecordioTest, ReadError) { - Write("foo"); - ForceError(); - AssertHasSubstr(Read(), "Data loss"); + TestReadError(RecordWriterOptions(), RecordReaderOptions()); +} + +TEST_F(RecordioTest, ReadErrorWithBuffering) { + RecordReaderOptions options; + options.buffer_size = 1 << 20; + TestReadError(RecordWriterOptions(), options); +} + +TEST_F(RecordioTest, ReadErrorWithCompression) { + TestReadError(RecordWriterOptions::CreateRecordWriterOptions("ZLIB"), + RecordReaderOptions::CreateRecordReaderOptions("ZLIB")); } TEST_F(RecordioTest, CorruptLength) { @@ -257,5 +340,6 @@ TEST_F(RecordioTest, ReadEnd) { CheckOffsetPastEndReturnsNoRecords(0); } TEST_F(RecordioTest, ReadPastEnd) { CheckOffsetPastEndReturnsNoRecords(5); } +} // namespace } // namespace io } // namespace tensorflow diff --git a/tensorflow/core/lib/io/zlib_inputstream.cc b/tensorflow/core/lib/io/zlib_inputstream.cc index 984fbc2810..47de36bf6c 100644 --- a/tensorflow/core/lib/io/zlib_inputstream.cc +++ b/tensorflow/core/lib/io/zlib_inputstream.cc @@ -25,8 +25,9 @@ ZlibInputStream::ZlibInputStream( InputStreamInterface* input_stream, size_t input_buffer_bytes, // size of z_stream.next_in buffer size_t output_buffer_bytes, // size of z_stream.next_out buffer - const ZlibCompressionOptions& zlib_options) - : input_stream_(input_stream), + const ZlibCompressionOptions& zlib_options, bool owns_input_stream) + : owns_input_stream_(owns_input_stream), + input_stream_(input_stream), input_buffer_capacity_(input_buffer_bytes), output_buffer_capacity_(output_buffer_bytes), z_stream_input_(new Bytef[input_buffer_capacity_]), @@ -37,14 +38,25 @@ ZlibInputStream::ZlibInputStream( InitZlibBuffer(); } +ZlibInputStream::ZlibInputStream(InputStreamInterface* input_stream, + size_t input_buffer_bytes, + size_t output_buffer_bytes, + const ZlibCompressionOptions& zlib_options) + : ZlibInputStream(input_stream, input_buffer_bytes, output_buffer_bytes, + zlib_options, false) {} + ZlibInputStream::~ZlibInputStream() { if (z_stream_) { inflateEnd(z_stream_.get()); } + if (owns_input_stream_) { + delete input_stream_; + } } Status ZlibInputStream::Reset() { TF_RETURN_IF_ERROR(input_stream_->Reset()); + inflateEnd(z_stream_.get()); InitZlibBuffer(); bytes_read_ = 0; return Status::OK(); diff --git a/tensorflow/core/lib/io/zlib_inputstream.h b/tensorflow/core/lib/io/zlib_inputstream.h index 9c7e14441c..37339163ee 100644 --- a/tensorflow/core/lib/io/zlib_inputstream.h +++ b/tensorflow/core/lib/io/zlib_inputstream.h @@ -40,7 +40,15 @@ class ZlibInputStream : public InputStreamInterface { // Create a ZlibInputStream for `input_stream` with a buffer of size // `input_buffer_bytes` bytes for reading contents from `input_stream` and // another buffer with size `output_buffer_bytes` for caching decompressed - // contents. Does *not* take ownership of "input_stream". + // contents. + // + // Takes ownership of `input_stream` iff `owns_input_stream` is true. + ZlibInputStream(InputStreamInterface* input_stream, size_t input_buffer_bytes, + size_t output_buffer_bytes, + const ZlibCompressionOptions& zlib_options, + bool owns_input_stream); + + // Equivalent to the previous constructor with owns_input_stream=false. ZlibInputStream(InputStreamInterface* input_stream, size_t input_buffer_bytes, size_t output_buffer_bytes, const ZlibCompressionOptions& zlib_options); @@ -65,10 +73,11 @@ class ZlibInputStream : public InputStreamInterface { private: void InitZlibBuffer(); - InputStreamInterface* input_stream_; // Not owned - size_t input_buffer_capacity_; // Size of z_stream_input_ - size_t output_buffer_capacity_; // Size of z_stream_output_ - char* next_unread_byte_; // Next unread byte in z_stream_output_ + const bool owns_input_stream_; + InputStreamInterface* input_stream_; + size_t input_buffer_capacity_; // Size of z_stream_input_ + size_t output_buffer_capacity_; // Size of z_stream_output_ + char* next_unread_byte_; // Next unread byte in z_stream_output_ // Buffer for storing contents read from compressed stream. // TODO(srbs): Consider using circular buffers. That would greatly simplify |