diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-05-27 12:22:25 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-05-27 12:22:25 -0700 |
commit | abc815798b21ce99c28444d7df5b48573de3e237 (patch) | |
tree | 1da04fe66a1560267a1a6ac3d80bcc3632abad81 /unsupported/test/cxx11_tensor_morphing.cpp | |
parent | 570753759291b60b031d9afd285b3e1efe4f63bc (diff) |
Added a new operation to enable more powerful tensorindexing.
Diffstat (limited to 'unsupported/test/cxx11_tensor_morphing.cpp')
-rw-r--r-- | unsupported/test/cxx11_tensor_morphing.cpp | 130 |
1 files changed, 130 insertions, 0 deletions
diff --git a/unsupported/test/cxx11_tensor_morphing.cpp b/unsupported/test/cxx11_tensor_morphing.cpp index eb3b891fd..2465ce9f2 100644 --- a/unsupported/test/cxx11_tensor_morphing.cpp +++ b/unsupported/test/cxx11_tensor_morphing.cpp @@ -315,6 +315,131 @@ static void test_slice_raw_data() VERIFY_IS_EQUAL(slice6.data(), tensor.data()); } + +template<int DataLayout> +static void test_strided_slice() +{ + typedef Tensor<float, 5, DataLayout> Tensor5f; + typedef Eigen::DSizes<Eigen::DenseIndex, 5> Index5; + typedef Tensor<float, 2, DataLayout> Tensor2f; + typedef Eigen::DSizes<Eigen::DenseIndex, 2> Index2; + Tensor<float, 5, DataLayout> tensor(2,3,5,7,11); + tensor.setRandom(); + + if(true) { + Tensor<float, 2, DataLayout> tensor(7,11); + tensor.setRandom(); + Tensor2f slice(2,3); + Index2 strides(-2,-1); + Index2 indicesStart(5,7); + Index2 indicesStop(0,4); + slice = tensor.stridedSlice(indicesStart, indicesStop, strides); + for (int j = 0; j < 2; ++j) { + for (int k = 0; k < 3; ++k) { + VERIFY_IS_EQUAL(slice(j,k), tensor(5-2*j,7-k)); + } + } + } + + if(true) { + Tensor<float, 2, DataLayout> tensor(7,11); + tensor.setRandom(); + Tensor2f slice(0,1); + Index2 strides(1,1); + Index2 indicesStart(5,4); + Index2 indicesStop(5,5); + slice = tensor.stridedSlice(indicesStart, indicesStop, strides); + } + + if(true) { // test clamped degenerate interavls + Tensor<float, 2, DataLayout> tensor(7,11); + tensor.setRandom(); + Tensor2f slice(7,11); + Index2 strides(1,-1); + Index2 indicesStart(-3,20); // should become 0,10 + Index2 indicesStop(20,-11); // should become 11, -1 + slice = tensor.stridedSlice(indicesStart, indicesStop, strides); + for (int j = 0; j < 7; ++j) { + for (int k = 0; k < 11; ++k) { + VERIFY_IS_EQUAL(slice(j,k), tensor(j,10-k)); + } + } + } + + if(true) { + Tensor5f slice1(1,1,1,1,1); + Eigen::DSizes<Eigen::DenseIndex, 5> indicesStart(1, 2, 3, 4, 5); + Eigen::DSizes<Eigen::DenseIndex, 5> indicesStop(2, 3, 4, 5, 6); + Eigen::DSizes<Eigen::DenseIndex, 5> strides(1, 1, 1, 1, 1); + slice1 = tensor.stridedSlice(indicesStart, indicesStop, strides); + VERIFY_IS_EQUAL(slice1(0,0,0,0,0), tensor(1,2,3,4,5)); + } + + if(true) { + Tensor5f slice(1,1,2,2,3); + Index5 start(1, 1, 3, 4, 5); + Index5 stop(2, 2, 5, 6, 8); + Index5 strides(1, 1, 1, 1, 1); + slice = tensor.stridedSlice(start, stop, strides); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + for (int k = 0; k < 3; ++k) { + VERIFY_IS_EQUAL(slice(0,0,i,j,k), tensor(1,1,3+i,4+j,5+k)); + } + } + } + } + if(true) { + Tensor5f slice(1,1,2,2,3); + Index5 strides3(1, 1, -2, 1, -1); + Index5 indices3Start(1, 1, 4, 4, 7); + Index5 indices3Stop(2, 2, 0, 6, 4); + slice = tensor.stridedSlice(indices3Start, indices3Stop, strides3); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + for (int k = 0; k < 3; ++k) { + VERIFY_IS_EQUAL(slice(0,0,i,j,k), tensor(1,1,4-2*i,4+j,7-k)); + } + } + } + } + + if(false) { // tests degenerate interval + Tensor5f slice(1,1,2,2,3); + Index5 strides3(1, 1, 2, 1, 1); + Index5 indices3Start(1, 1, 4, 4, 7); + Index5 indices3Stop(2, 2, 0, 6, 4); + slice = tensor.stridedSlice(indices3Start, indices3Stop, strides3); + } +} + +template<int DataLayout> +static void test_strided_slice_write() +{ + typedef Tensor<float, 2, DataLayout> Tensor2f; + typedef Eigen::DSizes<Eigen::DenseIndex, 2> Index2; + + Tensor<float, 2, DataLayout> tensor(7,11),tensor2(7,11); + tensor.setRandom(); + tensor2=tensor; + Tensor2f slice(2,3); + + slice.setRandom(); + + Index2 strides(1,1); + Index2 indicesStart(3,4); + Index2 indicesStop(5,7); + Index2 lengths(2,3); + + tensor.slice(indicesStart,lengths)=slice; + tensor2.stridedSlice(indicesStart,indicesStop,strides)=slice; + + for(int i=0;i<7;i++) for(int j=0;j<11;j++){ + VERIFY_IS_EQUAL(tensor(i,j), tensor2(i,j)); + } +} + + template<int DataLayout> static void test_composition() { @@ -351,6 +476,11 @@ void test_cxx11_tensor_morphing() CALL_SUBTEST(test_slice_raw_data<ColMajor>()); CALL_SUBTEST(test_slice_raw_data<RowMajor>()); + CALL_SUBTEST(test_strided_slice_write<ColMajor>()); + CALL_SUBTEST(test_strided_slice<ColMajor>()); + CALL_SUBTEST(test_strided_slice_write<RowMajor>()); + CALL_SUBTEST(test_strided_slice<RowMajor>()); + CALL_SUBTEST(test_composition<ColMajor>()); CALL_SUBTEST(test_composition<RowMajor>()); } |