diff options
author | 2015-11-06 16:27:58 -0800 | |
---|---|---|
committer | 2015-11-06 16:27:58 -0800 | |
commit | f41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch) | |
tree | ef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/core/lib/io/recordio_test.cc |
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation
using data flow graphs.
Base CL: 107276108
Diffstat (limited to 'tensorflow/core/lib/io/recordio_test.cc')
-rw-r--r-- | tensorflow/core/lib/io/recordio_test.cc | 245 |
1 files changed, 245 insertions, 0 deletions
diff --git a/tensorflow/core/lib/io/recordio_test.cc b/tensorflow/core/lib/io/recordio_test.cc new file mode 100644 index 0000000000..3e9c816443 --- /dev/null +++ b/tensorflow/core/lib/io/recordio_test.cc @@ -0,0 +1,245 @@ +#include "tensorflow/core/lib/io/record_reader.h" +#include <gtest/gtest.h> +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/hash/crc32c.h" +#include "tensorflow/core/lib/io/record_writer.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { +namespace io { + +// Construct a string of the specified length made out of the supplied +// partial string. +static string BigString(const string& partial_string, size_t n) { + string result; + while (result.size() < n) { + result.append(partial_string); + } + result.resize(n); + return result; +} + +// Construct a string from a number +static string NumberString(int n) { + char buf[50]; + snprintf(buf, sizeof(buf), "%d.", n); + return string(buf); +} + +// Return a skewed potentially long string +static string RandomSkewedString(int i, random::SimplePhilox* rnd) { + return BigString(NumberString(i), rnd->Skewed(17)); +} + +class RecordioTest : public testing::Test { + private: + class StringDest : public WritableFile { + public: + string contents_; + + Status Close() override { return Status::OK(); } + Status Flush() override { return Status::OK(); } + Status Sync() override { return Status::OK(); } + Status Append(const StringPiece& slice) override { + contents_.append(slice.data(), slice.size()); + return Status::OK(); + } + }; + + class StringSource : public RandomAccessFile { + public: + StringPiece contents_; + mutable bool force_error_; + mutable bool returned_partial_; + StringSource() : force_error_(false), returned_partial_(false) {} + + Status Read(uint64 offset, size_t n, StringPiece* result, + char* scratch) const override { + EXPECT_FALSE(returned_partial_) << "must not Read() after eof/error"; + + if (force_error_) { + force_error_ = false; + returned_partial_ = true; + return errors::DataLoss("read error"); + } + + if (offset >= contents_.size()) { + return errors::OutOfRange("end of file"); + } + + if (contents_.size() < offset + n) { + n = contents_.size() - offset; + returned_partial_ = true; + } + *result = StringPiece(contents_.data() + offset, n); + return Status::OK(); + } + }; + + StringDest dest_; + StringSource source_; + bool reading_; + uint64 readpos_; + RecordWriter* writer_; + RecordReader* reader_; + + public: + RecordioTest() + : reading_(false), + readpos_(0), + writer_(new RecordWriter(&dest_)), + reader_(new RecordReader(&source_)) {} + + ~RecordioTest() override { + delete writer_; + delete reader_; + } + + void Write(const string& msg) { + ASSERT_TRUE(!reading_) << "Write() after starting to read"; + ASSERT_OK(writer_->WriteRecord(StringPiece(msg))); + } + + size_t WrittenBytes() const { return dest_.contents_.size(); } + + string Read() { + if (!reading_) { + reading_ = true; + source_.contents_ = StringPiece(dest_.contents_); + } + string record; + Status s = reader_->ReadRecord(&readpos_, &record); + if (s.ok()) { + return record; + } else if (errors::IsOutOfRange(s)) { + return "EOF"; + } else { + return s.ToString(); + } + } + + void IncrementByte(int offset, int delta) { + dest_.contents_[offset] += delta; + } + + void SetByte(int offset, char new_byte) { + dest_.contents_[offset] = new_byte; + } + + void ShrinkSize(int bytes) { + dest_.contents_.resize(dest_.contents_.size() - bytes); + } + + void FixChecksum(int header_offset, int len) { + // Compute crc of type/len/data + uint32_t crc = crc32c::Value(&dest_.contents_[header_offset + 6], 1 + len); + crc = crc32c::Mask(crc); + core::EncodeFixed32(&dest_.contents_[header_offset], crc); + } + + void ForceError() { source_.force_error_ = true; } + + void StartReadingAt(uint64_t initial_offset) { readpos_ = initial_offset; } + + void CheckOffsetPastEndReturnsNoRecords(uint64_t offset_past_end) { + Write("foo"); + Write("bar"); + Write(BigString("x", 10000)); + reading_ = true; + source_.contents_ = StringPiece(dest_.contents_); + uint64 offset = WrittenBytes() + offset_past_end; + string record; + Status s = reader_->ReadRecord(&offset, &record); + ASSERT_TRUE(errors::IsOutOfRange(s)) << s; + } +}; + +TEST_F(RecordioTest, Empty) { ASSERT_EQ("EOF", Read()); } + +TEST_F(RecordioTest, ReadWrite) { + Write("foo"); + Write("bar"); + Write(""); + Write("xxxx"); + ASSERT_EQ("foo", Read()); + ASSERT_EQ("bar", Read()); + ASSERT_EQ("", Read()); + ASSERT_EQ("xxxx", Read()); + ASSERT_EQ("EOF", Read()); + ASSERT_EQ("EOF", Read()); // Make sure reads at eof work +} + +TEST_F(RecordioTest, ManyRecords) { + for (int i = 0; i < 100000; i++) { + Write(NumberString(i)); + } + for (int i = 0; i < 100000; i++) { + ASSERT_EQ(NumberString(i), Read()); + } + ASSERT_EQ("EOF", Read()); +} + +TEST_F(RecordioTest, RandomRead) { + const int N = 500; + { + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + for (int i = 0; i < N; i++) { + Write(RandomSkewedString(i, &rnd)); + } + } + { + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + for (int i = 0; i < N; i++) { + ASSERT_EQ(RandomSkewedString(i, &rnd), Read()); + } + } + ASSERT_EQ("EOF", Read()); +} + +// Tests of all the error paths in log_reader.cc follow: +static void AssertHasSubstr(StringPiece s, StringPiece expected) { + EXPECT_TRUE(StringPiece(s).contains(expected)) << s << " does not contain " + << expected; +} + +TEST_F(RecordioTest, ReadError) { + Write("foo"); + ForceError(); + AssertHasSubstr(Read(), "Data loss"); +} + +TEST_F(RecordioTest, CorruptLength) { + Write("foo"); + IncrementByte(6, 100); + AssertHasSubstr(Read(), "Data loss"); +} + +TEST_F(RecordioTest, CorruptLengthCrc) { + Write("foo"); + IncrementByte(10, 100); + AssertHasSubstr(Read(), "Data loss"); +} + +TEST_F(RecordioTest, CorruptData) { + Write("foo"); + IncrementByte(14, 10); + AssertHasSubstr(Read(), "Data loss"); +} + +TEST_F(RecordioTest, CorruptDataCrc) { + Write("foo"); + IncrementByte(WrittenBytes() - 1, 10); + AssertHasSubstr(Read(), "Data loss"); +} + +TEST_F(RecordioTest, ReadEnd) { CheckOffsetPastEndReturnsNoRecords(0); } + +TEST_F(RecordioTest, ReadPastEnd) { CheckOffsetPastEndReturnsNoRecords(5); } + +} // namespace io +} // namespace tensorflow |