#include "tensorflow/core/util/tensor_slice_set.h"

#include <utility>
#include <vector>

#include <gtest/gtest.h>
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/status.h"

namespace tensorflow {

namespace checkpoint {

namespace {

// A simple test: we have a 2-d tensor of shape 4 X 5 that looks like this:
//
//   0   1   2   3   4
//   5   6   7   8   9
//  10  11  12  13  14
//  15  16  17  18  19
//
// We assume this is a row-major matrix.
//
// We store the tensor in a couple of slices and verify that we can recover all
// of them.
//
// Each query below ASSERTs that Query() succeeded before looking at the output
// buffer: the buffer is not initialised by the test, so inspecting it after a
// failed query would read indeterminate values.
TEST(TensorSliceSetTest, QueryTwoD) {
  const TensorShape shape({4, 5});

  TensorSliceSet tss(shape, DT_FLOAT);

  // We store a few slices.

  // Slice #1 is the top two rows:
  //   0   1   2   3   4
  //   5   6   7   8   9
  //   .   .   .   .   .
  //   .   .   .   .   .
  const float src_1[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  const TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-");
  TF_CHECK_OK(tss.Register(slice_1, "", src_1));

  // Slice #2 is the bottom left corner
  //   .   .   .   .   .
  //   .   .   .   .   .
  //  10  11  12   .   .
  //  15  16  17   .   .
  const float src_2[] = {10, 11, 12, 15, 16, 17};
  const TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3");
  TF_CHECK_OK(tss.Register(slice_2, "", src_2));

  // Slice #3 is the bottom right corner
  //   .   .   .   .   .
  //   .   .   .   .   .
  //   .   .   .   .   .
  //   .   .   .  18  19
  const float src_3[] = {18, 19};
  const TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2");
  TF_CHECK_OK(tss.Register(slice_3, "", src_3));

  // Notice that we leave a hole in the tensor
  //   .   .   .   .   .
  //   .   .   .   .   .
  //   .   .   . (13) (14)
  //   .   .   .   .   .

  // Now we query some of the slices

  // Slice #1 is an exact match
  //   0   1   2   3   4
  //   5   6   7   8   9
  //   .   .   .   .   .
  //   .   .   .   .   .
  {
    const TensorSlice s = TensorSlice::ParseOrDie("0,2:-");
    const float expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    float results[10];
    ASSERT_TRUE(tss.Query(s, results));
    for (int i = 0; i < 10; ++i) {
      EXPECT_EQ(expected[i], results[i]);
    }
  }

  // Slice #2 is a subset match
  //   .   .   .   .   .
  //   5   6   7   8   9
  //   .   .   .   .   .
  //   .   .   .   .   .
  {
    const TensorSlice s = TensorSlice::ParseOrDie("1,1:-");
    const float expected[] = {5, 6, 7, 8, 9};
    float results[5];
    ASSERT_TRUE(tss.Query(s, results));
    for (int i = 0; i < 5; ++i) {
      EXPECT_EQ(expected[i], results[i]);
    }
  }

  // Slice #3 is a more complicated match: it needs the combination of a couple
  // of slices
  //   .   .   .   .   .
  //   5   6   7   .   .
  //  10  11  12   .   .
  //   .   .   .   .   .
  {
    const TensorSlice s = TensorSlice::ParseOrDie("1,2:0,3");
    const float expected[] = {5, 6, 7, 10, 11, 12};
    float results[6];
    ASSERT_TRUE(tss.Query(s, results));
    for (int i = 0; i < 6; ++i) {
      EXPECT_EQ(expected[i], results[i]);
    }
  }

  // Slice #4 includes the hole and so there is no match
  //   .   .   .   .   .
  //   .   .   7   8   9
  //   .   .  12  13  14
  //   .   .   .   .   .
  {
    const TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3");
    float results[6];
    EXPECT_FALSE(tss.Query(s, results));
  }
}

// Testing the meta version of the tensor slice set.
//
// The same three slices are registered as in QueryTwoD, but with a tag per
// slice and no data pointer. QueryMeta() is then expected to report which
// registered (slice, tag) pairs are needed to answer each query.
//
// The result count is ASSERTed before any element is indexed so that a wrong
// count fails the test instead of indexing out of range.
TEST(TensorSliceSetTest, QueryMetaTwoD) {
  const TensorShape shape({4, 5});

  TensorSliceSet tss(shape, DT_INT32);

  // We store a few slices.

  // Slice #1 is the top two rows:
  //   0   1   2   3   4
  //   5   6   7   8   9
  //   .   .   .   .   .
  //   .   .   .   .   .
  const TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-");
  TF_CHECK_OK(tss.Register(slice_1, "slice_1", nullptr));

  // Slice #2 is the bottom left corner
  //   .   .   .   .   .
  //   .   .   .   .   .
  //  10  11  12   .   .
  //  15  16  17   .   .
  const TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3");
  TF_CHECK_OK(tss.Register(slice_2, "slice_2", nullptr));

  // Slice #3 is the bottom right corner
  //   .   .   .   .   .
  //   .   .   .   .   .
  //   .   .   .   .   .
  //   .   .   .  18  19
  const TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2");
  TF_CHECK_OK(tss.Register(slice_3, "slice_3", nullptr));

  // Notice that we leave a hole in the tensor
  //   .   .   .   .   .
  //   .   .   .   .   .
  //   .   .   . (13) (14)
  //   .   .   .   .   .

  // Now we query some of the slices

  // Slice #1 is an exact match
  //   0   1   2   3   4
  //   5   6   7   8   9
  //   .   .   .   .   .
  //   .   .   .   .   .
  // We just need slice_1 for this
  {
    const TensorSlice s = TensorSlice::ParseOrDie("0,2:-");
    std::vector<std::pair<TensorSlice, string>> results;
    EXPECT_TRUE(tss.QueryMeta(s, &results));
    ASSERT_EQ(1u, results.size());
    EXPECT_EQ("0,2:-", results[0].first.DebugString());
    EXPECT_EQ("slice_1", results[0].second);
  }

  // Slice #2 is a subset match
  //   .   .   .   .   .
  //   5   6   7   8   9
  //   .   .   .   .   .
  //   .   .   .   .   .
  // We just need slice_1 for this
  {
    const TensorSlice s = TensorSlice::ParseOrDie("1,1:-");
    std::vector<std::pair<TensorSlice, string>> results;
    EXPECT_TRUE(tss.QueryMeta(s, &results));
    ASSERT_EQ(1u, results.size());
    EXPECT_EQ("0,2:-", results[0].first.DebugString());
    EXPECT_EQ("slice_1", results[0].second);
  }

  // Slice #3 is a more complicated match: it needs the combination of a couple
  // of slices
  //   .   .   .   .   .
  //   5   6   7   .   .
  //  10  11  12   .   .
  //   .   .   .   .   .
  // We need both slice_1 and slice_2 for this.
  {
    const TensorSlice s = TensorSlice::ParseOrDie("1,2:0,3");
    std::vector<std::pair<TensorSlice, string>> results;
    EXPECT_TRUE(tss.QueryMeta(s, &results));
    ASSERT_EQ(2u, results.size());
    // NOTE(review): the checks below pin the order in which QueryMeta reports
    // the two slices (slice_2 first, then slice_1) -- confirm that this order
    // is guaranteed by TensorSliceSet rather than incidental.
    EXPECT_EQ("2,2:0,3", results[0].first.DebugString());
    EXPECT_EQ("slice_2", results[0].second);
    EXPECT_EQ("0,2:-", results[1].first.DebugString());
    EXPECT_EQ("slice_1", results[1].second);
  }

  // Slice #4 includes the hole and so there is no match
  //   .   .   .   .   .
  //   .   .   7   8   9
  //   .   .  12  13  14
  //   .   .   .   .   .
  {
    const TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3");
    std::vector<std::pair<TensorSlice, string>> results;
    EXPECT_FALSE(tss.QueryMeta(s, &results));
    EXPECT_EQ(0u, results.size());
  }
}

}  // namespace

}  // namespace checkpoint

}  // namespace tensorflow