aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/lib
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-19 13:35:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-19 13:41:23 -0700
commit58fdd0dfce6d4c71fa7d381190987ccad33da0b6 (patch)
tree9b541b583535135bd03b46c4279992fda6580c4f /tensorflow/core/lib
parent661ad6be85fa611fa297bc8b8bacef752bef7ffc (diff)
Test io::RecordWriter.flush()
Requires changing flush_mode from default Z_NO_FLUSH See tensorflow/core/lib/io/zlib_compression_options.h PiperOrigin-RevId: 205293231
Diffstat (limited to 'tensorflow/core/lib')
-rw-r--r--tensorflow/core/lib/io/record_reader_writer_test.cc84
1 files changed, 84 insertions, 0 deletions
diff --git a/tensorflow/core/lib/io/record_reader_writer_test.cc b/tensorflow/core/lib/io/record_reader_writer_test.cc
index 95ac040602..c36c909399 100644
--- a/tensorflow/core/lib/io/record_reader_writer_test.cc
+++ b/tensorflow/core/lib/io/record_reader_writer_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/io/record_writer.h"
+#include <zlib.h>
#include <vector>
#include "tensorflow/core/platform/env.h"
@@ -33,6 +34,89 @@ static std::vector<int> BufferSizes() {
12, 13, 14, 15, 16, 17, 18, 19, 20, 65536};
}
+namespace {
+
+io::RecordReaderOptions GetMatchingReaderOptions(
+ const io::RecordWriterOptions& options) {
+ if (options.compression_type == io::RecordWriterOptions::ZLIB_COMPRESSION) {
+ return io::RecordReaderOptions::CreateRecordReaderOptions("ZLIB");
+ }
+ return io::RecordReaderOptions::CreateRecordReaderOptions("");
+}
+
+uint64 GetFileSize(const string& fname) {
+ Env* env = Env::Default();
+ uint64 fsize;
+ TF_CHECK_OK(env->GetFileSize(fname, &fsize));
+ return fsize;
+}
+
+void VerifyFlush(const io::RecordWriterOptions& options) {
+ std::vector<string> records = {
+ "abcdefghijklmnopqrstuvwxyz",
+ "ZYXWVUTSRQPONMLKJIHGFEDCBA0123456789!@#$%^&*()",
+ "G5SyohOL9UmXofSOOwWDrv9hoLLMYPJbG9r38t3uBRcHxHj2PdKcPDuZmKW62RIY",
+ "aaaaaaaaaaaaaaaaaaaaaaaaaa",
+ };
+
+ Env* env = Env::Default();
+ string fname = testing::TmpDir() + "/record_reader_writer_flush_test";
+
+ std::unique_ptr<WritableFile> file;
+ TF_CHECK_OK(env->NewWritableFile(fname, &file));
+ io::RecordWriter writer(file.get(), options);
+
+ std::unique_ptr<RandomAccessFile> read_file;
+ TF_CHECK_OK(env->NewRandomAccessFile(fname, &read_file));
+ io::RecordReaderOptions read_options = GetMatchingReaderOptions(options);
+ io::RecordReader reader(read_file.get(), read_options);
+
+ EXPECT_EQ(GetFileSize(fname), 0);
+ for (size_t i = 0; i < records.size(); i++) {
+ uint64 start_size = GetFileSize(fname);
+
+ // Write a new record.
+ TF_EXPECT_OK(writer.WriteRecord(records[i]));
+ TF_CHECK_OK(writer.Flush());
+ TF_CHECK_OK(file->Flush());
+
+ // Verify that file size has changed after file flush.
+ uint64 new_size = GetFileSize(fname);
+ EXPECT_GT(new_size, start_size);
+
+ // Verify that file has all records written so far and no more.
+ uint64 offset = 0;
+ string record;
+ for (size_t j = 0; j <= i; j++) {
+ // Check that j'th record is written correctly.
+ TF_CHECK_OK(reader.ReadRecord(&offset, &record));
+ EXPECT_EQ(record, records[j]);
+ }
+
+ // Verify that file has no more records.
+ CHECK_EQ(reader.ReadRecord(&offset, &record).code(), error::OUT_OF_RANGE);
+ }
+}
+
+} // namespace
+
+TEST(RecordReaderWriterTest, TestFlush) {
+ io::RecordWriterOptions options;
+ VerifyFlush(options);
+}
+
+TEST(RecordReaderWriterTest, TestZlibSyncFlush) {
+ io::RecordWriterOptions options;
+ options.compression_type = io::RecordWriterOptions::ZLIB_COMPRESSION;
+ // The default flush_mode is Z_NO_FLUSH and only writes to the file when the
+ // buffer is full or the file is closed, which makes testing harder.
+ // By using Z_SYNC_FLUSH the test can verify Flush does write out records of
+ // approximately the right size at the right times.
+ options.zlib_options.flush_mode = Z_SYNC_FLUSH;
+
+ VerifyFlush(options);
+}
+
TEST(RecordReaderWriterTest, TestBasics) {
Env* env = Env::Default();
string fname = testing::TmpDir() + "/record_reader_writer_test";