aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/lib
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/lib')
-rw-r--r--tensorflow/core/lib/io/record_reader.cc53
-rw-r--r--tensorflow/core/lib/io/record_reader.h25
-rw-r--r--tensorflow/core/lib/io/record_reader_writer_test.cc7
3 files changed, 85 insertions, 0 deletions
diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc
index f93ebea771..e22adcd569 100644
--- a/tensorflow/core/lib/io/record_reader.cc
+++ b/tensorflow/core/lib/io/record_reader.cc
@@ -108,6 +108,59 @@ Status RecordReader::ReadChecksummed(uint64 offset, size_t n, string* result) {
return Status::OK();
}
+Status RecordReader::GetMetadata(Metadata* md) {
+ if (!md) {
+ return errors::InvalidArgument(
+ "Metadata object call to GetMetadata() was null");
+ }
+
+ // Compute the metadata of the TFRecord file if not cached.
+ if (!cached_metadata_) {
+ TF_RETURN_IF_ERROR(input_stream_->Reset());
+
+ int64 data_size = 0;
+ int64 entries = 0;
+
+ // Within the loop, we always increment offset positively, so this
+ // loop should be guaranteed to either return after reaching EOF
+ // or encountering an error.
+ uint64 offset = 0;
+ string record;
+ while (true) {
+ // Read header, containing size of data.
+ Status s = ReadChecksummed(offset, sizeof(uint64), &record);
+ if (!s.ok()) {
+ if (errors::IsOutOfRange(s)) {
+ // We should reach out of range when the record file is complete.
+ break;
+ }
+ return s;
+ }
+
+ // Read the length of the data.
+ const uint64 length = core::DecodeFixed64(record.data());
+
+ // Skip reading the actual data since we just want the number
+ // of records and the size of the data.
+ TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(length + kFooterSize));
+ offset += kHeaderSize + length + kFooterSize;
+
+ // Increment running stats.
+ data_size += length;
+ ++entries;
+ }
+
+ cached_metadata_.reset(new Metadata());
+ cached_metadata_->stats.entries = entries;
+ cached_metadata_->stats.data_size = data_size;
+ cached_metadata_->stats.file_size =
+ data_size + (kHeaderSize + kFooterSize) * entries;
+ }
+
+ md->stats = cached_metadata_->stats;
+ return Status::OK();
+}
+
Status RecordReader::ReadRecord(uint64* offset, string* record) {
// Position the input stream.
int64 curr_pos = input_stream_->Tell();
diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h
index 11af1366b0..17444660d4 100644
--- a/tensorflow/core/lib/io/record_reader.h
+++ b/tensorflow/core/lib/io/record_reader.h
@@ -66,6 +66,18 @@ class RecordReader {
static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
static const size_t kFooterSize = sizeof(uint32);
+ // Statistics (sizes are in units of bytes)
+ struct Stats {
+ int64 file_size = -1;
+ int64 data_size = -1;
+ int64 entries = -1; // Number of values
+ };
+
+ // Metadata for the TFRecord file.
+ struct Metadata {
+ Stats stats;
+ };
+
// Create a reader that will return log records from "*file".
// "*file" must remain live while this Reader is in use.
explicit RecordReader(
@@ -79,6 +91,17 @@ class RecordReader {
// OUT_OF_RANGE for end of file, or something else for an error.
Status ReadRecord(uint64* offset, string* record);
+ // Return the metadata of the Record file.
+ //
+ // The current implementation scans the file to completion,
+ // skipping over the data regions, to extract the metadata once
+ // on the first call to GetStats(). An improved implementation
+ // would change RecordWriter to write the metadata into TFRecord
+ // so that GetMetadata() could be a const method.
+ //
+ // 'metadata' must not be nullptr.
+ Status GetMetadata(Metadata* md);
+
private:
Status ReadChecksummed(uint64 offset, size_t n, string* result);
@@ -86,6 +109,8 @@ class RecordReader {
std::unique_ptr<InputStreamInterface> input_stream_;
bool last_read_failed_;
+ std::unique_ptr<Metadata> cached_metadata_;
+
TF_DISALLOW_COPY_AND_ASSIGN(RecordReader);
};
diff --git a/tensorflow/core/lib/io/record_reader_writer_test.cc b/tensorflow/core/lib/io/record_reader_writer_test.cc
index 13bea1f8f1..a88d34d293 100644
--- a/tensorflow/core/lib/io/record_reader_writer_test.cc
+++ b/tensorflow/core/lib/io/record_reader_writer_test.cc
@@ -147,6 +147,13 @@ TEST(RecordReaderWriterTest, TestBasics) {
EXPECT_EQ("abc", record);
TF_CHECK_OK(reader.ReadRecord(&offset, &record));
EXPECT_EQ("defg", record);
+
+ io::RecordReader::Metadata md;
+ TF_ASSERT_OK(reader.GetMetadata(&md));
+ EXPECT_EQ(2, md.stats.entries);
+ EXPECT_EQ(7, md.stats.data_size);
+ // Two entries have 16 bytes of header/footer each.
+ EXPECT_EQ(39, md.stats.file_size);
}
}
}