aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/lib/io/recordio_test.cc
diff options
context:
space:
mode:
authorGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-06 16:27:58 -0800
committerGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-06 16:27:58 -0800
commitf41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch)
treeef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/core/lib/io/recordio_test.cc
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation using data flow graphs. Base CL: 107276108
Diffstat (limited to 'tensorflow/core/lib/io/recordio_test.cc')
-rw-r--r--tensorflow/core/lib/io/recordio_test.cc245
1 files changed, 245 insertions, 0 deletions
diff --git a/tensorflow/core/lib/io/recordio_test.cc b/tensorflow/core/lib/io/recordio_test.cc
new file mode 100644
index 0000000000..3e9c816443
--- /dev/null
+++ b/tensorflow/core/lib/io/recordio_test.cc
@@ -0,0 +1,245 @@
+#include "tensorflow/core/lib/io/record_reader.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/hash/crc32c.h"
+#include "tensorflow/core/lib/io/record_writer.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+namespace io {
+
+// Construct a string of the specified length made out of the supplied
+// partial string.
+static string BigString(const string& partial_string, size_t n) {
+ string result;
+ while (result.size() < n) {
+ result.append(partial_string);
+ }
+ result.resize(n);
+ return result;
+}
+
+// Construct a string from a number
+static string NumberString(int n) {
+ char buf[50];
+ snprintf(buf, sizeof(buf), "%d.", n);
+ return string(buf);
+}
+
+// Return a skewed potentially long string
+static string RandomSkewedString(int i, random::SimplePhilox* rnd) {
+ return BigString(NumberString(i), rnd->Skewed(17));
+}
+
+class RecordioTest : public testing::Test {
+ private:
+ class StringDest : public WritableFile {
+ public:
+ string contents_;
+
+ Status Close() override { return Status::OK(); }
+ Status Flush() override { return Status::OK(); }
+ Status Sync() override { return Status::OK(); }
+ Status Append(const StringPiece& slice) override {
+ contents_.append(slice.data(), slice.size());
+ return Status::OK();
+ }
+ };
+
+ class StringSource : public RandomAccessFile {
+ public:
+ StringPiece contents_;
+ mutable bool force_error_;
+ mutable bool returned_partial_;
+ StringSource() : force_error_(false), returned_partial_(false) {}
+
+ Status Read(uint64 offset, size_t n, StringPiece* result,
+ char* scratch) const override {
+ EXPECT_FALSE(returned_partial_) << "must not Read() after eof/error";
+
+ if (force_error_) {
+ force_error_ = false;
+ returned_partial_ = true;
+ return errors::DataLoss("read error");
+ }
+
+ if (offset >= contents_.size()) {
+ return errors::OutOfRange("end of file");
+ }
+
+ if (contents_.size() < offset + n) {
+ n = contents_.size() - offset;
+ returned_partial_ = true;
+ }
+ *result = StringPiece(contents_.data() + offset, n);
+ return Status::OK();
+ }
+ };
+
+ StringDest dest_;
+ StringSource source_;
+ bool reading_;
+ uint64 readpos_;
+ RecordWriter* writer_;
+ RecordReader* reader_;
+
+ public:
+ RecordioTest()
+ : reading_(false),
+ readpos_(0),
+ writer_(new RecordWriter(&dest_)),
+ reader_(new RecordReader(&source_)) {}
+
+ ~RecordioTest() override {
+ delete writer_;
+ delete reader_;
+ }
+
+ void Write(const string& msg) {
+ ASSERT_TRUE(!reading_) << "Write() after starting to read";
+ ASSERT_OK(writer_->WriteRecord(StringPiece(msg)));
+ }
+
+ size_t WrittenBytes() const { return dest_.contents_.size(); }
+
+ string Read() {
+ if (!reading_) {
+ reading_ = true;
+ source_.contents_ = StringPiece(dest_.contents_);
+ }
+ string record;
+ Status s = reader_->ReadRecord(&readpos_, &record);
+ if (s.ok()) {
+ return record;
+ } else if (errors::IsOutOfRange(s)) {
+ return "EOF";
+ } else {
+ return s.ToString();
+ }
+ }
+
+ void IncrementByte(int offset, int delta) {
+ dest_.contents_[offset] += delta;
+ }
+
+ void SetByte(int offset, char new_byte) {
+ dest_.contents_[offset] = new_byte;
+ }
+
+ void ShrinkSize(int bytes) {
+ dest_.contents_.resize(dest_.contents_.size() - bytes);
+ }
+
+ void FixChecksum(int header_offset, int len) {
+ // Compute crc of type/len/data
+ uint32_t crc = crc32c::Value(&dest_.contents_[header_offset + 6], 1 + len);
+ crc = crc32c::Mask(crc);
+ core::EncodeFixed32(&dest_.contents_[header_offset], crc);
+ }
+
+ void ForceError() { source_.force_error_ = true; }
+
+ void StartReadingAt(uint64_t initial_offset) { readpos_ = initial_offset; }
+
+ void CheckOffsetPastEndReturnsNoRecords(uint64_t offset_past_end) {
+ Write("foo");
+ Write("bar");
+ Write(BigString("x", 10000));
+ reading_ = true;
+ source_.contents_ = StringPiece(dest_.contents_);
+ uint64 offset = WrittenBytes() + offset_past_end;
+ string record;
+ Status s = reader_->ReadRecord(&offset, &record);
+ ASSERT_TRUE(errors::IsOutOfRange(s)) << s;
+ }
+};
+
+TEST_F(RecordioTest, Empty) { ASSERT_EQ("EOF", Read()); }
+
+TEST_F(RecordioTest, ReadWrite) {
+ Write("foo");
+ Write("bar");
+ Write("");
+ Write("xxxx");
+ ASSERT_EQ("foo", Read());
+ ASSERT_EQ("bar", Read());
+ ASSERT_EQ("", Read());
+ ASSERT_EQ("xxxx", Read());
+ ASSERT_EQ("EOF", Read());
+ ASSERT_EQ("EOF", Read()); // Make sure reads at eof work
+}
+
+TEST_F(RecordioTest, ManyRecords) {
+ for (int i = 0; i < 100000; i++) {
+ Write(NumberString(i));
+ }
+ for (int i = 0; i < 100000; i++) {
+ ASSERT_EQ(NumberString(i), Read());
+ }
+ ASSERT_EQ("EOF", Read());
+}
+
+TEST_F(RecordioTest, RandomRead) {
+ const int N = 500;
+ {
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ for (int i = 0; i < N; i++) {
+ Write(RandomSkewedString(i, &rnd));
+ }
+ }
+ {
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ for (int i = 0; i < N; i++) {
+ ASSERT_EQ(RandomSkewedString(i, &rnd), Read());
+ }
+ }
+ ASSERT_EQ("EOF", Read());
+}
+
+// Tests of all the error paths in log_reader.cc follow:
+static void AssertHasSubstr(StringPiece s, StringPiece expected) {
+ EXPECT_TRUE(StringPiece(s).contains(expected)) << s << " does not contain "
+ << expected;
+}
+
+TEST_F(RecordioTest, ReadError) {
+ Write("foo");
+ ForceError();
+ AssertHasSubstr(Read(), "Data loss");
+}
+
+TEST_F(RecordioTest, CorruptLength) {
+ Write("foo");
+ IncrementByte(6, 100);
+ AssertHasSubstr(Read(), "Data loss");
+}
+
+TEST_F(RecordioTest, CorruptLengthCrc) {
+ Write("foo");
+ IncrementByte(10, 100);
+ AssertHasSubstr(Read(), "Data loss");
+}
+
+TEST_F(RecordioTest, CorruptData) {
+ Write("foo");
+ IncrementByte(14, 10);
+ AssertHasSubstr(Read(), "Data loss");
+}
+
+TEST_F(RecordioTest, CorruptDataCrc) {
+ Write("foo");
+ IncrementByte(WrittenBytes() - 1, 10);
+ AssertHasSubstr(Read(), "Data loss");
+}
+
+TEST_F(RecordioTest, ReadEnd) { CheckOffsetPastEndReturnsNoRecords(0); }
+
+TEST_F(RecordioTest, ReadPastEnd) { CheckOffsetPastEndReturnsNoRecords(5); }
+
+} // namespace io
+} // namespace tensorflow