aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/lib/io
diff options
context:
space:
mode:
authorGravatar Brennan Saeta <saeta@google.com>2017-08-23 13:26:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-23 13:30:42 -0700
commite2f549e668bdb71ec2fadd49a9b42143048e277e (patch)
treeb5836afe114b64d1889e33ed9d769ba26508efd0 /tensorflow/core/lib/io
parent84138ef8bbb258ba5822808e19c54865f345cbef (diff)
Add buffering to RecordReader, and reader datasets.
Currently RecordReader reads from random access files. When reading from high-latency data stores (such as GCS), it's often preferable to buffer reads. This change adds buffering support to the reader datasets, and passes the configurations down to their respective backend implementations. PiperOrigin-RevId: 166245223
Diffstat (limited to 'tensorflow/core/lib/io')
-rw-r--r--tensorflow/core/lib/io/record_reader.cc70
-rw-r--r--tensorflow/core/lib/io/record_reader.h48
2 files changed, 92 insertions, 26 deletions
diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc
index ff2fd48de9..c3b87ee5bf 100644
--- a/tensorflow/core/lib/io/record_reader.cc
+++ b/tensorflow/core/lib/io/record_reader.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/crc32c.h"
+#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/compression.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
#include "tensorflow/core/platform/env.h"
@@ -56,14 +57,18 @@ RecordReaderOptions RecordReaderOptions::CreateRecordReaderOptions(
RecordReader::RecordReader(RandomAccessFile* file,
const RecordReaderOptions& options)
: src_(file), options_(options) {
+ if (options.buffer_size > 0) {
+ input_stream_.reset(new BufferedInputStream(file, options.buffer_size));
+ } else {
+ input_stream_.reset(new RandomAccessInputStream(file));
+ }
if (options.compression_type == RecordReaderOptions::ZLIB_COMPRESSION) {
// We don't have zlib available on all embedded platforms, so fail.
#if defined(IS_SLIM_BUILD)
LOG(FATAL) << "Zlib compression is unsupported on mobile platforms.";
#else // IS_SLIM_BUILD
- random_input_stream_.reset(new RandomAccessInputStream(file));
zlib_input_stream_.reset(new ZlibInputStream(
- random_input_stream_.get(), options.zlib_options.input_buffer_size,
+ input_stream_.get(), options.zlib_options.input_buffer_size,
options.zlib_options.output_buffer_size, options.zlib_options));
#endif // IS_SLIM_BUILD
} else if (options.compression_type == RecordReaderOptions::NONE) {
@@ -73,11 +78,6 @@ RecordReader::RecordReader(RandomAccessFile* file,
}
}
-RecordReader::~RecordReader() {
- zlib_input_stream_.reset(nullptr);
- random_input_stream_.reset(nullptr);
-}
-
// Read n+4 bytes from file, verify that checksum of first n bytes is
// stored in the last 4 bytes and store the first n bytes in *result.
// May use *storage as backing store.
@@ -116,22 +116,46 @@ Status RecordReader::ReadChecksummed(uint64 offset, size_t n,
*result = StringPiece(storage->data(), n);
} else {
#endif // IS_SLIM_BUILD
- // This version supports reading from arbitrary offsets
- // since we are accessing the random access file directly.
- StringPiece data;
- TF_RETURN_IF_ERROR(src_->Read(offset, expected, &data, &(*storage)[0]));
- if (data.size() != expected) {
- if (data.empty()) {
- return errors::OutOfRange("eof");
- } else {
- return errors::DataLoss("truncated record at ", offset);
+ if (options_.buffer_size > 0) {
+ // If we have a buffer, we assume that the file is being read
+ // sequentially, and we use the underlying implementation to read the
+ // data.
+ //
+ // No checks are done to validate that the file is being read
+ // sequentially.
+ TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(expected, storage));
+
+ if (storage->size() != expected) {
+ if (storage->empty()) {
+ return errors::OutOfRange("eof");
+ } else {
+ return errors::DataLoss("truncated record at ", offset);
+ }
}
+
+ const uint32 masked_crc = core::DecodeFixed32(storage->data() + n);
+ if (crc32c::Unmask(masked_crc) != crc32c::Value(storage->data(), n)) {
+ return errors::DataLoss("corrupted record at ", offset);
+ }
+ *result = StringPiece(storage->data(), n);
+ } else {
+ // This version supports reading from arbitrary offsets
+ // since we are accessing the random access file directly.
+ StringPiece data;
+ TF_RETURN_IF_ERROR(src_->Read(offset, expected, &data, &(*storage)[0]));
+ if (data.size() != expected) {
+ if (data.empty()) {
+ return errors::OutOfRange("eof");
+ } else {
+ return errors::DataLoss("truncated record at ", offset);
+ }
+ }
+ const uint32 masked_crc = core::DecodeFixed32(data.data() + n);
+ if (crc32c::Unmask(masked_crc) != crc32c::Value(data.data(), n)) {
+ return errors::DataLoss("corrupted record at ", offset);
+ }
+ *result = StringPiece(data.data(), n);
}
- uint32 masked_crc = core::DecodeFixed32(data.data() + n);
- if (crc32c::Unmask(masked_crc) != crc32c::Value(data.data(), n)) {
- return errors::DataLoss("corrupted record at ", offset);
- }
- *result = StringPiece(data.data(), n);
#if !defined(IS_SLIM_BUILD)
}
#endif // IS_SLIM_BUILD
@@ -172,5 +196,9 @@ Status RecordReader::ReadRecord(uint64* offset, string* record) {
return Status::OK();
}
+SequentialRecordReader::SequentialRecordReader(
+ RandomAccessFile* file, const RecordReaderOptions& options)
+ : underlying_(file, options), offset_(0) {}
+
} // namespace io
} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h
index 6c92b14963..e4f6a5b492 100644
--- a/tensorflow/core/lib/io/record_reader.h
+++ b/tensorflow/core/lib/io/record_reader.h
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#if !defined(IS_SLIM_BUILD)
-#include "tensorflow/core/lib/io/random_inputstream.h"
+#include "tensorflow/core/lib/io/inputstream_interface.h"
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_inputstream.h"
#endif // IS_SLIM_BUILD
@@ -37,6 +37,11 @@ class RecordReaderOptions {
enum CompressionType { NONE = 0, ZLIB_COMPRESSION = 1 };
CompressionType compression_type = NONE;
+ // If buffer_size is non-zero, then all reads must be sequential, and no
+ // skipping around is permitted. (Note: this is the same behavior as reading
+ // compressed files.) Consider using SequentialRecordReader.
+ int64 buffer_size = 0;
+
static RecordReaderOptions CreateRecordReaderOptions(
const string& compression_type);
@@ -46,18 +51,27 @@ class RecordReaderOptions {
#endif // IS_SLIM_BUILD
};
+// Low-level interface to read TFRecord files.
+//
+// If using compression or buffering, consider using SequentialRecordReader.
+//
+// Note: this class is not thread safe; external synchronization required.
class RecordReader {
public:
// Create a reader that will return log records from "*file".
// "*file" must remain live while this Reader is in use.
- RecordReader(RandomAccessFile* file,
- const RecordReaderOptions& options = RecordReaderOptions());
+ explicit RecordReader(
+ RandomAccessFile* file,
+ const RecordReaderOptions& options = RecordReaderOptions());
- virtual ~RecordReader();
+ virtual ~RecordReader() = default;
// Read the record at "*offset" into *record and update *offset to
// point to the offset of the next record. Returns OK on success,
// OUT_OF_RANGE for end of file, or something else for an error.
+ //
+ // Note: if buffering is used (with or without compression), access must be
+ // sequential.
Status ReadRecord(uint64* offset, string* record);
private:
@@ -66,14 +80,38 @@ class RecordReader {
RandomAccessFile* src_;
RecordReaderOptions options_;
+ std::unique_ptr<InputStreamInterface> input_stream_;
#if !defined(IS_SLIM_BUILD)
- std::unique_ptr<RandomAccessInputStream> random_input_stream_;
std::unique_ptr<ZlibInputStream> zlib_input_stream_;
#endif // IS_SLIM_BUILD
TF_DISALLOW_COPY_AND_ASSIGN(RecordReader);
};
+// High-level interface to read TFRecord files.
+//
+// Note: this class is not thread safe; external synchronization required.
+class SequentialRecordReader {
+ public:
+ // Create a reader that will return log records from "*file".
+ // "*file" must remain live while this Reader is in use.
+ explicit SequentialRecordReader(
+ RandomAccessFile* file,
+ const RecordReaderOptions& options = RecordReaderOptions());
+
+ virtual ~SequentialRecordReader() = default;
+
+ // Reads the next record in the file into *record. Returns OK on success,
+ // OUT_OF_RANGE for end of file, or something else for an error.
+ Status ReadRecord(string* record) {
+ return underlying_.ReadRecord(&offset_, record);
+ }
+
+ private:
+ RecordReader underlying_;
+ uint64 offset_ = 0;
+};
+
} // namespace io
} // namespace tensorflow