diff options
Diffstat (limited to 'tensorflow/core/lib')
-rw-r--r-- | tensorflow/core/lib/io/record_reader.cc | 53 | ||||
-rw-r--r-- | tensorflow/core/lib/io/record_reader.h | 25 | ||||
-rw-r--r-- | tensorflow/core/lib/io/record_reader_writer_test.cc | 7 |
3 files changed, 85 insertions, 0 deletions
diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc index f93ebea771..e22adcd569 100644 --- a/tensorflow/core/lib/io/record_reader.cc +++ b/tensorflow/core/lib/io/record_reader.cc @@ -108,6 +108,59 @@ Status RecordReader::ReadChecksummed(uint64 offset, size_t n, string* result) { return Status::OK(); } +Status RecordReader::GetMetadata(Metadata* md) { + if (!md) { + return errors::InvalidArgument( + "Metadata object call to GetMetadata() was null"); + } + + // Compute the metadata of the TFRecord file if not cached. + if (!cached_metadata_) { + TF_RETURN_IF_ERROR(input_stream_->Reset()); + + int64 data_size = 0; + int64 entries = 0; + + // Within the loop, we always increment offset positively, so this + // loop should be guaranteed to end: it either breaks on reaching EOF + // or returns on encountering an error. + uint64 offset = 0; + string record; + while (true) { + // Read the checksummed header, which contains the length of the data. + Status s = ReadChecksummed(offset, sizeof(uint64), &record); + if (!s.ok()) { + if (errors::IsOutOfRange(s)) { + // OutOfRange is expected once the end of the record file is reached. + break; + } + return s; + } + + // Decode the length of the data from the header. + const uint64 length = core::DecodeFixed64(record.data()); + + // Skip reading the actual data since we just want the number + // of records and the size of the data. + TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(length + kFooterSize)); + offset += kHeaderSize + length + kFooterSize; + + // Increment running stats. + data_size += length; + ++entries; + } + + cached_metadata_.reset(new Metadata()); + cached_metadata_->stats.entries = entries; + cached_metadata_->stats.data_size = data_size; + cached_metadata_->stats.file_size = + data_size + (kHeaderSize + kFooterSize) * entries; + } + + md->stats = cached_metadata_->stats; + return Status::OK(); +} + Status RecordReader::ReadRecord(uint64* offset, string* record) { // Position the input stream. 
int64 curr_pos = input_stream_->Tell(); diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h index 11af1366b0..17444660d4 100644 --- a/tensorflow/core/lib/io/record_reader.h +++ b/tensorflow/core/lib/io/record_reader.h @@ -66,6 +66,18 @@ class RecordReader { static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32); static const size_t kFooterSize = sizeof(uint32); + // Statistics (sizes are in units of bytes) + struct Stats { + int64 file_size = -1; + int64 data_size = -1; + int64 entries = -1; // Number of records + }; + + // Metadata for the TFRecord file. + struct Metadata { + Stats stats; + }; + // Create a reader that will return log records from "*file". // "*file" must remain live while this Reader is in use. explicit RecordReader( @@ -79,6 +91,17 @@ class RecordReader { // OUT_OF_RANGE for end of file, or something else for an error. Status ReadRecord(uint64* offset, string* record); + // Return the metadata of the TFRecord file. + // + // The current implementation scans the file to completion, + // skipping over the data regions, to extract the metadata once + // on the first call to GetMetadata(). An improved implementation + // would change RecordWriter to write the metadata into TFRecord + // so that GetMetadata() could be a const method. + // + // 'md' must not be nullptr. 
+ Status GetMetadata(Metadata* md); + private: Status ReadChecksummed(uint64 offset, size_t n, string* result); @@ -86,6 +109,8 @@ class RecordReader { std::unique_ptr<InputStreamInterface> input_stream_; bool last_read_failed_; + std::unique_ptr<Metadata> cached_metadata_; + TF_DISALLOW_COPY_AND_ASSIGN(RecordReader); }; diff --git a/tensorflow/core/lib/io/record_reader_writer_test.cc b/tensorflow/core/lib/io/record_reader_writer_test.cc index 13bea1f8f1..a88d34d293 100644 --- a/tensorflow/core/lib/io/record_reader_writer_test.cc +++ b/tensorflow/core/lib/io/record_reader_writer_test.cc @@ -147,6 +147,13 @@ TEST(RecordReaderWriterTest, TestBasics) { EXPECT_EQ("abc", record); TF_CHECK_OK(reader.ReadRecord(&offset, &record)); EXPECT_EQ("defg", record); + + io::RecordReader::Metadata md; + TF_ASSERT_OK(reader.GetMetadata(&md)); + EXPECT_EQ(2, md.stats.entries); + EXPECT_EQ(7, md.stats.data_size); + // Two entries have 16 bytes of header/footer each. + EXPECT_EQ(39, md.stats.file_size); } } } |