diff options
Diffstat (limited to 'tensorflow/core/util/tensor_slice_reader_test.cc')
-rw-r--r-- | tensorflow/core/util/tensor_slice_reader_test.cc | 395 |
1 files changed, 395 insertions, 0 deletions
diff --git a/tensorflow/core/util/tensor_slice_reader_test.cc b/tensorflow/core/util/tensor_slice_reader_test.cc new file mode 100644 index 0000000000..e14b920003 --- /dev/null +++ b/tensorflow/core/util/tensor_slice_reader_test.cc @@ -0,0 +1,395 @@ +#include "tensorflow/core/util/tensor_slice_reader.h" + +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/saved_tensor_slice_util.h" +#include "tensorflow/core/util/tensor_slice_writer.h" +#include "tensorflow/core/util/tensor_slice_reader_cache.h" +#include <gtest/gtest.h> + +namespace tensorflow { + +namespace checkpoint { + +namespace { + +// A simple test where we write a few tensor slices with a number of tensor +// slice writers and then read them back from a tensor slice reader. +// +// We have a 2-d tensor of shape 4 X 5 that looks like this: +// +// 0 1 2 3 4 +// 5 6 7 8 9 +// 10 11 12 13 14 +// 15 16 17 18 19 +// +// We assume this is a row-major matrix. + +void SimpleFloatHelper(TensorSliceWriter::CreateBuilderFunction create_function, + TensorSliceReader::OpenTableFunction open_function) { + const string fname_base = io::JoinPath(testing::TmpDir(), "float_checkpoint"); + + TensorShape shape({4, 5}); + + // File #0 contains a slice that is the top two rows: + // + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + const string fname = strings::StrCat(fname_base, "_0"); + TensorSliceWriter writer(fname, create_function); + const float data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + TensorSlice slice = TensorSlice::ParseOrDie("0,2:-"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + TF_CHECK_OK(writer.Finish()); + } + + // File #1 contains two slices: + // + // slice #0 is the bottom left corner + // . . . . . + // . . . . . + // 10 11 12 . . + // 15 16 17 . . + // + // slice #1 is the bottom right corner + // . . . . . + // . . . . . + // . . . . . + // . . . 18 19 + { + const string fname = strings::StrCat(fname_base, "_1"); + TensorSliceWriter writer(fname, create_function); + // slice #0 + { + const float data[] = {10, 11, 12, 15, 16, 17}; + TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + } + // slice #1 + { + const float data[] = {18, 19}; + TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + } + TF_CHECK_OK(writer.Finish()); + } + + // Notice that we leave a hole in the tensor + // . . . . . + // . . . . . + // . . . (13) (14) + // . . . . . + + // Now we need to read the tensor slices + const string filepattern = strings::StrCat(fname_base, "_*"); + TensorSliceReader reader(filepattern, open_function); + EXPECT_OK(reader.status()); + EXPECT_EQ(2, reader.num_files()); + + // We query some of the tensors + { + TensorShape shape; + DataType type; + EXPECT_TRUE(reader.HasTensor("test", &shape, &type)); + EXPECT_EQ( + "dim { size: 4 } " + "dim { size: 5 }", + shape.DebugString()); + EXPECT_EQ(DT_FLOAT, type); + EXPECT_FALSE(reader.HasTensor("don't exist", nullptr, nullptr)); + } + + // Now we query some slices + // + // Slice #1 is an exact match + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("0,2:-"); + float expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + float results[10]; + EXPECT_TRUE(reader.CopySliceData("test", s, results)); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(expected[i], results[i]); + } + } + + // Slice #2 is a subset match + // . . . . . + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,1:-"); + float expected[] = {5, 6, 7, 8, 9}; + float results[5]; + EXPECT_TRUE(reader.CopySliceData("test", s, results)); + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(expected[i], results[i]); + } + } + + // Slice #4 includes the hole and so there is no match + // . . . . . + // . . 7 8 9 + // . . 12 13 14 + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3"); + float results[6]; + EXPECT_FALSE(reader.CopySliceData("test", s, results)); + } +} + +TEST(TensorSliceReaderTest, SimpleFloat) { + SimpleFloatHelper(CreateTableTensorSliceBuilder, OpenTableTensorSliceReader); +} + +template <typename T, typename U> +void SimpleIntXHelper(TensorSliceWriter::CreateBuilderFunction create_function, + TensorSliceReader::OpenTableFunction open_function, + const string& checkpoint_file) { + const string fname_base = io::JoinPath(testing::TmpDir(), checkpoint_file); + + TensorShape shape({4, 5}); + + // File #0 contains a slice that is the top two rows: + // + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + const string fname = strings::StrCat(fname_base, "_0"); + TensorSliceWriter writer(fname, create_function); + const T data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + TensorSlice slice = TensorSlice::ParseOrDie("0,2:-"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + TF_CHECK_OK(writer.Finish()); + } + + // File #1 contains two slices: + // + // slice #0 is the bottom left corner + // . . . . . + // . . . . . + // 10 11 12 . . + // 15 16 17 . . + // + // slice #1 is the bottom right corner + // . . . . . + // . . . . . + // . . . . . + // . . . 18 19 + { + const string fname = strings::StrCat(fname_base, "_1"); + TensorSliceWriter writer(fname, create_function); + // slice #0 + { + const T data[] = {10, 11, 12, 15, 16, 17}; + TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + } + // slice #1 + { + const T data[] = {18, 19}; + TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + } + TF_CHECK_OK(writer.Finish()); + } + + // Notice that we leave a hole in the tensor + // . . . . . + // . . . . . + // . . . (13) (14) + // . . . . . + + // Now we need to read the tensor slices + const string filepattern = strings::StrCat(fname_base, "_*"); + TensorSliceReader reader(filepattern, open_function); + EXPECT_OK(reader.status()); + EXPECT_EQ(2, reader.num_files()); + + // We query some of the tensors + { + TensorShape shape; + DataType type; + EXPECT_TRUE(reader.HasTensor("test", &shape, &type)); + EXPECT_EQ( + "dim { size: 4 } " + "dim { size: 5 }", + shape.DebugString()); + EXPECT_EQ(DataTypeToEnum<T>::v(), type); + EXPECT_FALSE(reader.HasTensor("don't exist", nullptr, nullptr)); + } + + // Now we query some slices + // + // Slice #1 is an exact match + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("0,2:-"); + T expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + U results[10]; + EXPECT_TRUE(reader.CopySliceData("test", s, results)); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(expected[i], results[i]); + } + } + + // Slice #2 is a subset match + // . . . . . + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,1:-"); + T expected[] = {5, 6, 7, 8, 9}; + U results[5]; + EXPECT_TRUE(reader.CopySliceData("test", s, results)); + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(expected[i], results[i]); + } + } + + // Slice #4 includes the hole and so there is no match + // . . . . . + // . . 7 8 9 + // . . 12 13 14 + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3"); + U results[6]; + EXPECT_FALSE(reader.CopySliceData("test", s, results)); + } +} + +#define TEST_SIMPLE_INT(TYPE, SAVED_TYPE) \ + TEST(TensorSliceReaderTest, Simple##TYPE) { \ + SimpleIntXHelper<TYPE, SAVED_TYPE>(CreateTableTensorSliceBuilder, \ + OpenTableTensorSliceReader, \ + #TYPE "_checkpoint"); \ + } + +TEST_SIMPLE_INT(int32, int32) +TEST_SIMPLE_INT(int64, int64) +TEST_SIMPLE_INT(int16, int32) +TEST_SIMPLE_INT(int8, int32) +TEST_SIMPLE_INT(uint8, int32) + +void CachedTensorSliceReaderTesterHelper( + TensorSliceWriter::CreateBuilderFunction create_function, + TensorSliceReader::OpenTableFunction open_function) { + const string fname_base = io::JoinPath(testing::TmpDir(), "float_checkpoint"); + + TensorShape shape({4, 5}); + + // File #0 contains a slice that is the top two rows: + // + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + const string fname = strings::StrCat(fname_base, "_0"); + TensorSliceWriter writer(fname, create_function); + const float data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + TensorSlice slice = TensorSlice::ParseOrDie("0,2:-"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + TF_CHECK_OK(writer.Finish()); + } + + // File #1 contains two slices: + // + // slice #0 is the bottom left corner + // . . . . . + // . . . . . + // 10 11 12 . . + // 15 16 17 . . + // + // slice #1 is the bottom right corner + // . . . . . + // . . . . . + // . . . . . + // . . . 18 19 + { + const string fname = strings::StrCat(fname_base, "_1"); + TensorSliceWriter writer(fname, create_function); + // slice #0 + { + const float data[] = {10, 11, 12, 15, 16, 17}; + TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + } + // slice #1 + { + const float data[] = {18, 19}; + TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + } + TF_CHECK_OK(writer.Finish()); + } + + // Notice that we leave a hole in the tensor + // . . . . . + // . . . . . + // . . . (13) (14) + // . . . . . + + // Now we need to read the tensor slices + TensorSliceReaderCache cache; + const string filepattern = strings::StrCat(fname_base, "_*"); + const TensorSliceReader* reader = cache.GetReader( + filepattern, open_function, TensorSliceReader::kLoadAllShards); + EXPECT_TRUE(reader != nullptr); + EXPECT_EQ(2, reader->num_files()); + + // We query some of the tensors + { + TensorShape shape; + DataType type; + EXPECT_TRUE(reader->HasTensor("test", &shape, &type)); + EXPECT_EQ( + "dim { size: 4 } " + "dim { size: 5 }", + shape.DebugString()); + EXPECT_EQ(DT_FLOAT, type); + EXPECT_FALSE(reader->HasTensor("don't exist", nullptr, nullptr)); + } + + // Make sure the reader is cached. + const TensorSliceReader* reader2 = cache.GetReader( + filepattern, open_function, TensorSliceReader::kLoadAllShards); + EXPECT_EQ(reader, reader2); + + reader = cache.GetReader("file_does_not_exist", open_function, + TensorSliceReader::kLoadAllShards); + EXPECT_TRUE(reader == nullptr); +} + +TEST(CachedTensorSliceReaderTest, SimpleFloat) { + CachedTensorSliceReaderTesterHelper(CreateTableTensorSliceBuilder, + OpenTableTensorSliceReader); +} + +} // namespace + +} // namespace checkpoint + +} // namespace tensorflow |