diff options
Diffstat (limited to 'tensorflow/core/util/tensor_slice_set_test.cc')
-rw-r--r-- | tensorflow/core/util/tensor_slice_set_test.cc | 227 |
1 files changed, 227 insertions, 0 deletions
diff --git a/tensorflow/core/util/tensor_slice_set_test.cc b/tensorflow/core/util/tensor_slice_set_test.cc new file mode 100644 index 0000000000..fb2f46f34c --- /dev/null +++ b/tensorflow/core/util/tensor_slice_set_test.cc @@ -0,0 +1,227 @@ +#include "tensorflow/core/util/tensor_slice_set.h" + +#include "tensorflow/core/platform/logging.h" +#include <gtest/gtest.h> +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +namespace checkpoint { + +namespace { + +// A simple test: we have a 2-d tensor of shape 4 X 5 that looks like this: +// +// 0 1 2 3 4 +// 5 6 7 8 9 +// 10 11 12 13 14 +// 15 16 17 18 19 +// +// We assume this is a row-major matrix. +// +// We store the tensor in a couple of slices and verify that we can recover all +// of them. +TEST(TensorSliceSetTest, QueryTwoD) { + TensorShape shape({4, 5}); + + TensorSliceSet tss(shape, DT_FLOAT); + // We store a few slices. + + // Slice #1 is the top two rows: + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + const float src_1[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-"); + TF_CHECK_OK(tss.Register(slice_1, "", src_1)); + + // Slice #2 is the bottom left corner + // . . . . . + // . . . . . + // 10 11 12 . . + // 15 16 17 . . + const float src_2[] = {10, 11, 12, 15, 16, 17}; + TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3"); + TF_CHECK_OK(tss.Register(slice_2, "", src_2)); + + // Slice #3 is the bottom right corner + // . . . . . + // . . . . . + // . . . . . + // . . . 18 19 + const float src_3[] = {18, 19}; + TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2"); + TF_CHECK_OK(tss.Register(slice_3, "", src_3)); + + // Notice that we leave a hole in the tensor + // . . . . . + // . . . . . + // . . . (13) (14) + // . . . . . + + // Now we query some of the slices + + // Slice #1 is an exact match + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("0,2:-"); + float expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + float results[10]; + EXPECT_TRUE(tss.Query(s, results)); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(expected[i], results[i]); + } + } + + // Slice #2 is a subset match + // . . . . . + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,1:-"); + float expected[] = {5, 6, 7, 8, 9}; + float results[5]; + EXPECT_TRUE(tss.Query(s, results)); + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(expected[i], results[i]); + } + } + + // Slice #3 is a more complicated match: it needs the combination of a couple + // of slices + // . . . . . + // 5 6 7 . . + // 10 11 12 . . + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,2:0,3"); + float expected[] = {5, 6, 7, 10, 11, 12}; + float results[6]; + EXPECT_TRUE(tss.Query(s, results)); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(expected[i], results[i]); + } + } + + // Slice #4 includes the hole and so there is no match + // . . . . . + // . . 7 8 9 + // . . 12 13 14 + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3"); + float results[6]; + EXPECT_FALSE(tss.Query(s, results)); + } +} + +// Testing the meta version of the tensor slice set. +TEST(TensorSliceSetTest, QueryMetaTwoD) { + TensorShape shape({4, 5}); + + TensorSliceSet tss(shape, DT_INT32); + // We store a few slices. + + // Slice #1 is the top two rows: + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-"); + TF_CHECK_OK(tss.Register(slice_1, "slice_1", nullptr)); + + // Slice #2 is the bottom left corner + // . . . . . + // . . . . . + // 10 11 12 . . + // 15 16 17 . . + TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3"); + TF_CHECK_OK(tss.Register(slice_2, "slice_2", nullptr)); + + // Slice #3 is the bottom right corner + // . . . . . + // . . . . . + // . . . . . + // . . . 18 19 + TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2"); + TF_CHECK_OK(tss.Register(slice_3, "slice_3", nullptr)); + + // Notice that we leave a hole in the tensor + // . . . . . + // . . . . . + // . . . (13) (14) + // . . . . . + + // Now we query some of the slices + + // Slice #1 is an exact match + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + // We just need slice_1 for this + { + TensorSlice s = TensorSlice::ParseOrDie("0,2:-"); + std::vector<std::pair<TensorSlice, string>> results; + EXPECT_TRUE(tss.QueryMeta(s, &results)); + EXPECT_EQ(1, results.size()); + EXPECT_EQ("0,2:-", results[0].first.DebugString()); + EXPECT_EQ("slice_1", results[0].second); + } + + // Slice #2 is a subset match + // . . . . . + // 5 6 7 8 9 + // . . . . . + // . . . . . + // We just need slice_1 for this + { + TensorSlice s = TensorSlice::ParseOrDie("1,1:-"); + std::vector<std::pair<TensorSlice, string>> results; + EXPECT_TRUE(tss.QueryMeta(s, &results)); + EXPECT_EQ(1, results.size()); + EXPECT_EQ("0,2:-", results[0].first.DebugString()); + EXPECT_EQ("slice_1", results[0].second); + } + + // Slice #3 is a more complicated match: it needs the combination of a couple + // of slices + // . . . . . + // 5 6 7 . . + // 10 11 12 . . + // . . . . . + // We need both slice_1 and slice_2 for this. + { + TensorSlice s = TensorSlice::ParseOrDie("1,2:0,3"); + std::vector<std::pair<TensorSlice, string>> results; + EXPECT_TRUE(tss.QueryMeta(s, &results)); + EXPECT_EQ(2, results.size()); + EXPECT_EQ("2,2:0,3", results[0].first.DebugString()); + EXPECT_EQ("slice_2", results[0].second); + EXPECT_EQ("0,2:-", results[1].first.DebugString()); + EXPECT_EQ("slice_1", results[1].second); + } + + // Slice #4 includes the hole and so there is no match + // . . . . . + // . . 7 8 9 + // . . 12 13 14 + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3"); + std::vector<std::pair<TensorSlice, string>> results; + EXPECT_FALSE(tss.QueryMeta(s, &results)); + EXPECT_EQ(0, results.size()); + } +} + +} // namespace + +} // namespace checkpoint + +} // namespace tensorflow |