aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/tensor_slice_set_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/util/tensor_slice_set_test.cc')
-rw-r--r--tensorflow/core/util/tensor_slice_set_test.cc227
1 files changed, 227 insertions, 0 deletions
diff --git a/tensorflow/core/util/tensor_slice_set_test.cc b/tensorflow/core/util/tensor_slice_set_test.cc
new file mode 100644
index 0000000000..fb2f46f34c
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_set_test.cc
@@ -0,0 +1,227 @@
+#include "tensorflow/core/util/tensor_slice_set.h"
+
+#include "tensorflow/core/platform/logging.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+namespace {
+
+// A simple test: we have a 2-d tensor of shape 4 X 5 that looks like this:
+//
+// 0 1 2 3 4
+// 5 6 7 8 9
+// 10 11 12 13 14
+// 15 16 17 18 19
+//
+// We assume this is a row-major matrix.
+//
+// We store the tensor in a couple of slices and verify that we can recover all
+// of them.
+TEST(TensorSliceSetTest, QueryTwoD) {
+ TensorShape shape({4, 5});
+
+ TensorSliceSet tss(shape, DT_FLOAT);
+ // We store a few slices.
+
+ // Slice #1 is the top two rows:
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ const float src_1[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-");
+ TF_CHECK_OK(tss.Register(slice_1, "", src_1));
+
+ // Slice #2 is the bottom left corner
+ // . . . . .
+ // . . . . .
+ // 10 11 12 . .
+ // 15 16 17 . .
+ const float src_2[] = {10, 11, 12, 15, 16, 17};
+ TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3");
+ TF_CHECK_OK(tss.Register(slice_2, "", src_2));
+
+ // Slice #3 is the bottom right corner
+ // . . . . .
+ // . . . . .
+ // . . . . .
+ // . . . 18 19
+ const float src_3[] = {18, 19};
+ TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2");
+ TF_CHECK_OK(tss.Register(slice_3, "", src_3));
+
+ // Notice that we leave a hole in the tensor
+ // . . . . .
+ // . . . . .
+ // . . . (13) (14)
+ // . . . . .
+
+ // Now we query some of the slices
+
+ // Slice #1 is an exact match
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("0,2:-");
+ float expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ float results[10];
+ EXPECT_TRUE(tss.Query(s, results));
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_EQ(expected[i], results[i]);
+ }
+ }
+
+ // Slice #2 is a subset match
+ // . . . . .
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,1:-");
+ float expected[] = {5, 6, 7, 8, 9};
+ float results[5];
+ EXPECT_TRUE(tss.Query(s, results));
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_EQ(expected[i], results[i]);
+ }
+ }
+
+ // Slice #3 is a more complicated match: it needs the combination of a couple
+ // of slices
+ // . . . . .
+ // 5 6 7 . .
+ // 10 11 12 . .
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,2:0,3");
+ float expected[] = {5, 6, 7, 10, 11, 12};
+ float results[6];
+ EXPECT_TRUE(tss.Query(s, results));
+ for (int i = 0; i < 6; ++i) {
+ EXPECT_EQ(expected[i], results[i]);
+ }
+ }
+
+ // Slice #4 includes the hole and so there is no match
+ // . . . . .
+ // . . 7 8 9
+ // . . 12 13 14
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3");
+ float results[6];
+ EXPECT_FALSE(tss.Query(s, results));
+ }
+}
+
+// Testing the meta version of the tensor slice set.
+TEST(TensorSliceSetTest, QueryMetaTwoD) {
+ TensorShape shape({4, 5});
+
+ TensorSliceSet tss(shape, DT_INT32);
+ // We store a few slices.
+
+ // Slice #1 is the top two rows:
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-");
+ TF_CHECK_OK(tss.Register(slice_1, "slice_1", nullptr));
+
+ // Slice #2 is the bottom left corner
+ // . . . . .
+ // . . . . .
+ // 10 11 12 . .
+ // 15 16 17 . .
+ TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3");
+ TF_CHECK_OK(tss.Register(slice_2, "slice_2", nullptr));
+
+ // Slice #3 is the bottom right corner
+ // . . . . .
+ // . . . . .
+ // . . . . .
+ // . . . 18 19
+ TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2");
+ TF_CHECK_OK(tss.Register(slice_3, "slice_3", nullptr));
+
+ // Notice that we leave a hole in the tensor
+ // . . . . .
+ // . . . . .
+ // . . . (13) (14)
+ // . . . . .
+
+ // Now we query some of the slices
+
+ // Slice #1 is an exact match
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ // We just need slice_1 for this
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("0,2:-");
+ std::vector<std::pair<TensorSlice, string>> results;
+ EXPECT_TRUE(tss.QueryMeta(s, &results));
+ EXPECT_EQ(1, results.size());
+ EXPECT_EQ("0,2:-", results[0].first.DebugString());
+ EXPECT_EQ("slice_1", results[0].second);
+ }
+
+ // Slice #2 is a subset match
+ // . . . . .
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ // We just need slice_1 for this
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,1:-");
+ std::vector<std::pair<TensorSlice, string>> results;
+ EXPECT_TRUE(tss.QueryMeta(s, &results));
+ EXPECT_EQ(1, results.size());
+ EXPECT_EQ("0,2:-", results[0].first.DebugString());
+ EXPECT_EQ("slice_1", results[0].second);
+ }
+
+ // Slice #3 is a more complicated match: it needs the combination of a couple
+ // of slices
+ // . . . . .
+ // 5 6 7 . .
+ // 10 11 12 . .
+ // . . . . .
+ // We need both slice_1 and slice_2 for this.
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,2:0,3");
+ std::vector<std::pair<TensorSlice, string>> results;
+ EXPECT_TRUE(tss.QueryMeta(s, &results));
+ EXPECT_EQ(2, results.size());
+ EXPECT_EQ("2,2:0,3", results[0].first.DebugString());
+ EXPECT_EQ("slice_2", results[0].second);
+ EXPECT_EQ("0,2:-", results[1].first.DebugString());
+ EXPECT_EQ("slice_1", results[1].second);
+ }
+
+ // Slice #4 includes the hole and so there is no match
+ // . . . . .
+ // . . 7 8 9
+ // . . 12 13 14
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3");
+ std::vector<std::pair<TensorSlice, string>> results;
+ EXPECT_FALSE(tss.QueryMeta(s, &results));
+ EXPECT_EQ(0, results.size());
+ }
+}
+
+} // namespace
+
+} // namespace checkpoint
+
+} // namespace tensorflow