diff options
Diffstat (limited to 'tensorflow/core/util/tensor_slice_util_test.cc')
-rw-r--r-- | tensorflow/core/util/tensor_slice_util_test.cc | 91 |
1 files changed, 91 insertions, 0 deletions
diff --git a/tensorflow/core/util/tensor_slice_util_test.cc b/tensorflow/core/util/tensor_slice_util_test.cc new file mode 100644 index 0000000000..348b0c884e --- /dev/null +++ b/tensorflow/core/util/tensor_slice_util_test.cc @@ -0,0 +1,91 @@ +#include "tensorflow/core/util/tensor_slice_util.h" + +#include <gtest/gtest.h> + +namespace tensorflow { +namespace { + +// Testing copying data from one tensor slice to another tensor slice +TEST(TensorSliceUtilTest, CopyTensorSliceToTensorSlice) { + // We map out a 2-d tensor of size 4 X 5 and we want the final results look + // like this: + // + // 0 1 2 3 4 + // 5 6 7 8 9 + // 10 11 12 13 14 + // 15 16 17 18 19 + // + // We assume this is a row-major matrix + // + TensorShape shape({4, 5}); + + // We will try to do a couple of slice to slice copies. + + // Case 1: simple identity copy + // The slice is the "interior" of the matrix + // . . . . . + // . 6 7 8 . + // , 11 12 13 . + // . . . . . + { + TensorSlice slice_s = TensorSlice::ParseOrDie("1,2:1,3"); + TensorSlice slice_d = TensorSlice::ParseOrDie("1,2:1,3"); + const float ptr_s[] = {6, 7, 8, 11, 12, 13}; + float ptr_d[6]; + for (int i = 0; i < 6; ++i) { + ptr_d[i] = 0; + } + EXPECT_TRUE(CopyDataFromTensorSliceToTensorSlice(shape, slice_s, slice_d, + ptr_s, ptr_d)); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(ptr_s[i], ptr_d[i]); + } + } + + // Case 2: no intersection + { + TensorSlice slice_s = TensorSlice::ParseOrDie("1,2:1,3"); + TensorSlice slice_d = TensorSlice::ParseOrDie("3,1:2,3"); + const float ptr_s[] = {6, 7, 8, 11, 12, 13}; + float ptr_d[6]; + EXPECT_FALSE(CopyDataFromTensorSliceToTensorSlice(shape, slice_s, slice_d, + ptr_s, ptr_d)); + } + + // Case 3: a trickier case + // The source slice is on the upper left corner: + // 0 1 2 . . + // 5 6 7 . . + // 10 11 12 . . + // . . . . . + // + // The destination slice is the right part of the middle stripe: + // . . . . . + // . X X X X + // . X X X X + // . . . . . + // + // So we expect to copy over the 2X2 block: + // . . . . . + // . 6 7 . . + // . 11 12 . . + // . . . . . + { + TensorSlice slice_s = TensorSlice::ParseOrDie("0,3:0,3"); + TensorSlice slice_d = TensorSlice::ParseOrDie("1,2:1,4"); + const float ptr_s[] = {0, 1, 2, 5, 6, 7, 10, 11, 12}; + float ptr_d[8]; + for (int i = 0; i < 8; ++i) { + ptr_d[i] = 0; + } + EXPECT_TRUE(CopyDataFromTensorSliceToTensorSlice(shape, slice_s, slice_d, + ptr_s, ptr_d)); + const float expected[] = {6, 7, 0, 0, 11, 12, 0, 0}; + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(expected[i], ptr_d[i]); + } + } +} + +} // namespace +} // namespace tensorflow |