aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/lib
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2018-09-23 10:51:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-23 10:54:51 -0700
commit5b6a09f81f8088626b5d88ed7fe3f3414d7ae23e (patch)
tree42866a67e9bcd652a68ed61703ad8c8672f7f5da /tensorflow/core/lib
parentbb5fbbcb663c795dd7fc16e43a0eaaae53231fd9 (diff)
Add Metadata object for RecordReader and associated stats computing /
MD fetching method. TFRecord files do not contain a file-level header that describes the MD of the file. To avoid backwards compatibility issues, we add a lightweight function that computes the statistics over the file once and caches the result for future calls. A future implementor could do a better job of computing GetMetadata() by having the RecordWriter emit these entries during writing so that GetMetadata() only reads. Doing so will require additional backwards compatibility checks to ensure that the function works both for old TFRecords and the new format. PiperOrigin-RevId: 214178704
Diffstat (limited to 'tensorflow/core/lib')
-rw-r--r--tensorflow/core/lib/io/record_reader.cc53
-rw-r--r--tensorflow/core/lib/io/record_reader.h25
-rw-r--r--tensorflow/core/lib/io/record_reader_writer_test.cc7
3 files changed, 85 insertions, 0 deletions
diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc
index f93ebea771..e22adcd569 100644
--- a/tensorflow/core/lib/io/record_reader.cc
+++ b/tensorflow/core/lib/io/record_reader.cc
@@ -108,6 +108,59 @@ Status RecordReader::ReadChecksummed(uint64 offset, size_t n, string* result) {
return Status::OK();
}
+Status RecordReader::GetMetadata(Metadata* md) {
+ if (!md) {
+ return errors::InvalidArgument(
+ "Metadata object call to GetMetadata() was null");
+ }
+
+ // Compute the metadata of the TFRecord file if not cached.
+ if (!cached_metadata_) {
+ TF_RETURN_IF_ERROR(input_stream_->Reset());
+
+ int64 data_size = 0;
+ int64 entries = 0;
+
+ // Within the loop, we always increment offset positively, so this
+ // loop should be guaranteed to either return after reaching EOF
+ // or encountering an error.
+ uint64 offset = 0;
+ string record;
+ while (true) {
+ // Read header, containing size of data.
+ Status s = ReadChecksummed(offset, sizeof(uint64), &record);
+ if (!s.ok()) {
+ if (errors::IsOutOfRange(s)) {
+ // We should reach out of range when the record file is complete.
+ break;
+ }
+ return s;
+ }
+
+ // Read the length of the data.
+ const uint64 length = core::DecodeFixed64(record.data());
+
+ // Skip reading the actual data since we just want the number
+ // of records and the size of the data.
+ TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(length + kFooterSize));
+ offset += kHeaderSize + length + kFooterSize;
+
+ // Increment running stats.
+ data_size += length;
+ ++entries;
+ }
+
+ cached_metadata_.reset(new Metadata());
+ cached_metadata_->stats.entries = entries;
+ cached_metadata_->stats.data_size = data_size;
+ cached_metadata_->stats.file_size =
+ data_size + (kHeaderSize + kFooterSize) * entries;
+ }
+
+ md->stats = cached_metadata_->stats;
+ return Status::OK();
+}
+
Status RecordReader::ReadRecord(uint64* offset, string* record) {
// Position the input stream.
int64 curr_pos = input_stream_->Tell();
diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h
index 11af1366b0..17444660d4 100644
--- a/tensorflow/core/lib/io/record_reader.h
+++ b/tensorflow/core/lib/io/record_reader.h
@@ -66,6 +66,18 @@ class RecordReader {
static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
static const size_t kFooterSize = sizeof(uint32);
+ // Statistics (sizes are in units of bytes)
+ struct Stats {
+ int64 file_size = -1;
+ int64 data_size = -1;
+ int64 entries = -1; // Number of values
+ };
+
+ // Metadata for the TFRecord file.
+ struct Metadata {
+ Stats stats;
+ };
+
// Create a reader that will return log records from "*file".
// "*file" must remain live while this Reader is in use.
explicit RecordReader(
@@ -79,6 +91,17 @@ class RecordReader {
// OUT_OF_RANGE for end of file, or something else for an error.
Status ReadRecord(uint64* offset, string* record);
+ // Return the metadata of the Record file.
+ //
+ // The current implementation scans the file to completion,
+ // skipping over the data regions, to extract the metadata once
+ // on the first call to GetStats(). An improved implementation
+ // would change RecordWriter to write the metadata into TFRecord
+ // so that GetMetadata() could be a const method.
+ //
+ // 'metadata' must not be nullptr.
+ Status GetMetadata(Metadata* md);
+
private:
Status ReadChecksummed(uint64 offset, size_t n, string* result);
@@ -86,6 +109,8 @@ class RecordReader {
std::unique_ptr<InputStreamInterface> input_stream_;
bool last_read_failed_;
+ std::unique_ptr<Metadata> cached_metadata_;
+
TF_DISALLOW_COPY_AND_ASSIGN(RecordReader);
};
diff --git a/tensorflow/core/lib/io/record_reader_writer_test.cc b/tensorflow/core/lib/io/record_reader_writer_test.cc
index 13bea1f8f1..a88d34d293 100644
--- a/tensorflow/core/lib/io/record_reader_writer_test.cc
+++ b/tensorflow/core/lib/io/record_reader_writer_test.cc
@@ -147,6 +147,13 @@ TEST(RecordReaderWriterTest, TestBasics) {
EXPECT_EQ("abc", record);
TF_CHECK_OK(reader.ReadRecord(&offset, &record));
EXPECT_EQ("defg", record);
+
+ io::RecordReader::Metadata md;
+ TF_ASSERT_OK(reader.GetMetadata(&md));
+ EXPECT_EQ(2, md.stats.entries);
+ EXPECT_EQ(7, md.stats.data_size);
+ // Two entries have 16 bytes of header/footer each.
+ EXPECT_EQ(39, md.stats.file_size);
}
}
}