From abc815798b21ce99c28444d7df5b48573de3e237 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Fri, 27 May 2016 12:22:25 -0700 Subject: Added a new operation to enable more powerful tensorindexing. --- unsupported/test/cxx11_tensor_morphing.cpp | 130 +++++++++++++++++++++++++++++ 1 file changed, 130 insertions(+) (limited to 'unsupported/test/cxx11_tensor_morphing.cpp') diff --git a/unsupported/test/cxx11_tensor_morphing.cpp b/unsupported/test/cxx11_tensor_morphing.cpp index eb3b891fd..2465ce9f2 100644 --- a/unsupported/test/cxx11_tensor_morphing.cpp +++ b/unsupported/test/cxx11_tensor_morphing.cpp @@ -315,6 +315,131 @@ static void test_slice_raw_data() VERIFY_IS_EQUAL(slice6.data(), tensor.data()); } + +template +static void test_strided_slice() +{ + typedef Tensor Tensor5f; + typedef Eigen::DSizes Index5; + typedef Tensor Tensor2f; + typedef Eigen::DSizes Index2; + Tensor tensor(2,3,5,7,11); + tensor.setRandom(); + + if(true) { + Tensor tensor(7,11); + tensor.setRandom(); + Tensor2f slice(2,3); + Index2 strides(-2,-1); + Index2 indicesStart(5,7); + Index2 indicesStop(0,4); + slice = tensor.stridedSlice(indicesStart, indicesStop, strides); + for (int j = 0; j < 2; ++j) { + for (int k = 0; k < 3; ++k) { + VERIFY_IS_EQUAL(slice(j,k), tensor(5-2*j,7-k)); + } + } + } + + if(true) { + Tensor tensor(7,11); + tensor.setRandom(); + Tensor2f slice(0,1); + Index2 strides(1,1); + Index2 indicesStart(5,4); + Index2 indicesStop(5,5); + slice = tensor.stridedSlice(indicesStart, indicesStop, strides); + } + + if(true) { // test clamped degenerate interavls + Tensor tensor(7,11); + tensor.setRandom(); + Tensor2f slice(7,11); + Index2 strides(1,-1); + Index2 indicesStart(-3,20); // should become 0,10 + Index2 indicesStop(20,-11); // should become 11, -1 + slice = tensor.stridedSlice(indicesStart, indicesStop, strides); + for (int j = 0; j < 7; ++j) { + for (int k = 0; k < 11; ++k) { + VERIFY_IS_EQUAL(slice(j,k), tensor(j,10-k)); + } + } + } + + if(true) { + Tensor5f slice1(1,1,1,1,1); + Eigen::DSizes indicesStart(1, 2, 3, 4, 5); + Eigen::DSizes indicesStop(2, 3, 4, 5, 6); + Eigen::DSizes strides(1, 1, 1, 1, 1); + slice1 = tensor.stridedSlice(indicesStart, indicesStop, strides); + VERIFY_IS_EQUAL(slice1(0,0,0,0,0), tensor(1,2,3,4,5)); + } + + if(true) { + Tensor5f slice(1,1,2,2,3); + Index5 start(1, 1, 3, 4, 5); + Index5 stop(2, 2, 5, 6, 8); + Index5 strides(1, 1, 1, 1, 1); + slice = tensor.stridedSlice(start, stop, strides); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + for (int k = 0; k < 3; ++k) { + VERIFY_IS_EQUAL(slice(0,0,i,j,k), tensor(1,1,3+i,4+j,5+k)); + } + } + } + } + if(true) { + Tensor5f slice(1,1,2,2,3); + Index5 strides3(1, 1, -2, 1, -1); + Index5 indices3Start(1, 1, 4, 4, 7); + Index5 indices3Stop(2, 2, 0, 6, 4); + slice = tensor.stridedSlice(indices3Start, indices3Stop, strides3); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + for (int k = 0; k < 3; ++k) { + VERIFY_IS_EQUAL(slice(0,0,i,j,k), tensor(1,1,4-2*i,4+j,7-k)); + } + } + } + } + + if(false) { // tests degenerate interval + Tensor5f slice(1,1,2,2,3); + Index5 strides3(1, 1, 2, 1, 1); + Index5 indices3Start(1, 1, 4, 4, 7); + Index5 indices3Stop(2, 2, 0, 6, 4); + slice = tensor.stridedSlice(indices3Start, indices3Stop, strides3); + } +} + +template +static void test_strided_slice_write() +{ + typedef Tensor Tensor2f; + typedef Eigen::DSizes Index2; + + Tensor tensor(7,11),tensor2(7,11); + tensor.setRandom(); + tensor2=tensor; + Tensor2f slice(2,3); + + slice.setRandom(); + + Index2 strides(1,1); + Index2 indicesStart(3,4); + Index2 indicesStop(5,7); + Index2 lengths(2,3); + + tensor.slice(indicesStart,lengths)=slice; + tensor2.stridedSlice(indicesStart,indicesStop,strides)=slice; + + for(int i=0;i<7;i++) for(int j=0;j<11;j++){ + VERIFY_IS_EQUAL(tensor(i,j), tensor2(i,j)); + } +} + + template static void test_composition() { @@ -351,6 +476,11 @@ void test_cxx11_tensor_morphing() CALL_SUBTEST(test_slice_raw_data()); CALL_SUBTEST(test_slice_raw_data()); + CALL_SUBTEST(test_strided_slice_write()); + CALL_SUBTEST(test_strided_slice()); + CALL_SUBTEST(test_strided_slice_write()); + CALL_SUBTEST(test_strided_slice()); + CALL_SUBTEST(test_composition()); CALL_SUBTEST(test_composition()); } -- cgit v1.2.3