aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/test/cxx11_tensor_morphing.cpp
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-05-27 12:22:25 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-05-27 12:22:25 -0700
commitabc815798b21ce99c28444d7df5b48573de3e237 (patch)
tree1da04fe66a1560267a1a6ac3d80bcc3632abad81 /unsupported/test/cxx11_tensor_morphing.cpp
parent570753759291b60b031d9afd285b3e1efe4f63bc (diff)
Added a new operation to enable more powerful tensorindexing.
Diffstat (limited to 'unsupported/test/cxx11_tensor_morphing.cpp')
-rw-r--r--unsupported/test/cxx11_tensor_morphing.cpp130
1 files changed, 130 insertions, 0 deletions
diff --git a/unsupported/test/cxx11_tensor_morphing.cpp b/unsupported/test/cxx11_tensor_morphing.cpp
index eb3b891fd..2465ce9f2 100644
--- a/unsupported/test/cxx11_tensor_morphing.cpp
+++ b/unsupported/test/cxx11_tensor_morphing.cpp
@@ -315,6 +315,131 @@ static void test_slice_raw_data()
VERIFY_IS_EQUAL(slice6.data(), tensor.data());
}
+
+template<int DataLayout>
+static void test_strided_slice()
+{
+ typedef Tensor<float, 5, DataLayout> Tensor5f;
+ typedef Eigen::DSizes<Eigen::DenseIndex, 5> Index5;
+ typedef Tensor<float, 2, DataLayout> Tensor2f;
+ typedef Eigen::DSizes<Eigen::DenseIndex, 2> Index2;
+ Tensor<float, 5, DataLayout> tensor(2,3,5,7,11);
+ tensor.setRandom();
+
+ if(true) {
+ Tensor<float, 2, DataLayout> tensor(7,11);
+ tensor.setRandom();
+ Tensor2f slice(2,3);
+ Index2 strides(-2,-1);
+ Index2 indicesStart(5,7);
+ Index2 indicesStop(0,4);
+ slice = tensor.stridedSlice(indicesStart, indicesStop, strides);
+ for (int j = 0; j < 2; ++j) {
+ for (int k = 0; k < 3; ++k) {
+ VERIFY_IS_EQUAL(slice(j,k), tensor(5-2*j,7-k));
+ }
+ }
+ }
+
+ if(true) {
+ Tensor<float, 2, DataLayout> tensor(7,11);
+ tensor.setRandom();
+ Tensor2f slice(0,1);
+ Index2 strides(1,1);
+ Index2 indicesStart(5,4);
+ Index2 indicesStop(5,5);
+ slice = tensor.stridedSlice(indicesStart, indicesStop, strides);
+ }
+
+ if(true) { // test clamped degenerate interavls
+ Tensor<float, 2, DataLayout> tensor(7,11);
+ tensor.setRandom();
+ Tensor2f slice(7,11);
+ Index2 strides(1,-1);
+ Index2 indicesStart(-3,20); // should become 0,10
+ Index2 indicesStop(20,-11); // should become 11, -1
+ slice = tensor.stridedSlice(indicesStart, indicesStop, strides);
+ for (int j = 0; j < 7; ++j) {
+ for (int k = 0; k < 11; ++k) {
+ VERIFY_IS_EQUAL(slice(j,k), tensor(j,10-k));
+ }
+ }
+ }
+
+ if(true) {
+ Tensor5f slice1(1,1,1,1,1);
+ Eigen::DSizes<Eigen::DenseIndex, 5> indicesStart(1, 2, 3, 4, 5);
+ Eigen::DSizes<Eigen::DenseIndex, 5> indicesStop(2, 3, 4, 5, 6);
+ Eigen::DSizes<Eigen::DenseIndex, 5> strides(1, 1, 1, 1, 1);
+ slice1 = tensor.stridedSlice(indicesStart, indicesStop, strides);
+ VERIFY_IS_EQUAL(slice1(0,0,0,0,0), tensor(1,2,3,4,5));
+ }
+
+ if(true) {
+ Tensor5f slice(1,1,2,2,3);
+ Index5 start(1, 1, 3, 4, 5);
+ Index5 stop(2, 2, 5, 6, 8);
+ Index5 strides(1, 1, 1, 1, 1);
+ slice = tensor.stridedSlice(start, stop, strides);
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 2; ++j) {
+ for (int k = 0; k < 3; ++k) {
+ VERIFY_IS_EQUAL(slice(0,0,i,j,k), tensor(1,1,3+i,4+j,5+k));
+ }
+ }
+ }
+ }
+ if(true) {
+ Tensor5f slice(1,1,2,2,3);
+ Index5 strides3(1, 1, -2, 1, -1);
+ Index5 indices3Start(1, 1, 4, 4, 7);
+ Index5 indices3Stop(2, 2, 0, 6, 4);
+ slice = tensor.stridedSlice(indices3Start, indices3Stop, strides3);
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 2; ++j) {
+ for (int k = 0; k < 3; ++k) {
+ VERIFY_IS_EQUAL(slice(0,0,i,j,k), tensor(1,1,4-2*i,4+j,7-k));
+ }
+ }
+ }
+ }
+
+ if(false) { // tests degenerate interval
+ Tensor5f slice(1,1,2,2,3);
+ Index5 strides3(1, 1, 2, 1, 1);
+ Index5 indices3Start(1, 1, 4, 4, 7);
+ Index5 indices3Stop(2, 2, 0, 6, 4);
+ slice = tensor.stridedSlice(indices3Start, indices3Stop, strides3);
+ }
+}
+
+template<int DataLayout>
+static void test_strided_slice_write()
+{
+ typedef Tensor<float, 2, DataLayout> Tensor2f;
+ typedef Eigen::DSizes<Eigen::DenseIndex, 2> Index2;
+
+ Tensor<float, 2, DataLayout> tensor(7,11),tensor2(7,11);
+ tensor.setRandom();
+ tensor2=tensor;
+ Tensor2f slice(2,3);
+
+ slice.setRandom();
+
+ Index2 strides(1,1);
+ Index2 indicesStart(3,4);
+ Index2 indicesStop(5,7);
+ Index2 lengths(2,3);
+
+ tensor.slice(indicesStart,lengths)=slice;
+ tensor2.stridedSlice(indicesStart,indicesStop,strides)=slice;
+
+ for(int i=0;i<7;i++) for(int j=0;j<11;j++){
+ VERIFY_IS_EQUAL(tensor(i,j), tensor2(i,j));
+ }
+}
+
+
template<int DataLayout>
static void test_composition()
{
@@ -351,6 +476,11 @@ void test_cxx11_tensor_morphing()
CALL_SUBTEST(test_slice_raw_data<ColMajor>());
CALL_SUBTEST(test_slice_raw_data<RowMajor>());
+ CALL_SUBTEST(test_strided_slice_write<ColMajor>());
+ CALL_SUBTEST(test_strided_slice<ColMajor>());
+ CALL_SUBTEST(test_strided_slice_write<RowMajor>());
+ CALL_SUBTEST(test_strided_slice<RowMajor>());
+
CALL_SUBTEST(test_composition<ColMajor>());
CALL_SUBTEST(test_composition<RowMajor>());
}