aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/test/cxx11_tensor_morphing.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/test/cxx11_tensor_morphing.cpp')
-rw-r--r--unsupported/test/cxx11_tensor_morphing.cpp165
1 files changed, 147 insertions, 18 deletions
diff --git a/unsupported/test/cxx11_tensor_morphing.cpp b/unsupported/test/cxx11_tensor_morphing.cpp
index eb3b891fd..f7de43110 100644
--- a/unsupported/test/cxx11_tensor_morphing.cpp
+++ b/unsupported/test/cxx11_tensor_morphing.cpp
@@ -13,6 +13,7 @@
using Eigen::Tensor;
+template<typename>
static void test_simple_reshape()
{
Tensor<float, 5> tensor1(2,3,1,7,1);
@@ -40,7 +41,7 @@ static void test_simple_reshape()
}
}
-
+template<typename>
static void test_reshape_in_expr() {
MatrixXf m1(2,3*5*7*11);
MatrixXf m2(3*5*7*11,13);
@@ -65,7 +66,7 @@ static void test_reshape_in_expr() {
}
}
-
+template<typename>
static void test_reshape_as_lvalue()
{
Tensor<float, 3> tensor(2,3,7);
@@ -114,6 +115,7 @@ static void test_simple_slice()
}
}
+template<typename=void>
static void test_const_slice()
{
const float b[1] = {42};
@@ -315,6 +317,128 @@ static void test_slice_raw_data()
VERIFY_IS_EQUAL(slice6.data(), tensor.data());
}
+
+template<int DataLayout>
+static void test_strided_slice()
+{
+ typedef Tensor<float, 5, DataLayout> Tensor5f;
+ typedef Eigen::DSizes<Eigen::DenseIndex, 5> Index5;
+ typedef Tensor<float, 2, DataLayout> Tensor2f;
+ typedef Eigen::DSizes<Eigen::DenseIndex, 2> Index2;
+ Tensor<float, 5, DataLayout> tensor(2,3,5,7,11);
+ Tensor<float, 2, DataLayout> tensor2(7,11);
+ tensor.setRandom();
+ tensor2.setRandom();
+
+ if (true) {
+ Tensor2f slice(2,3);
+ Index2 strides(-2,-1);
+ Index2 indicesStart(5,7);
+ Index2 indicesStop(0,4);
+ slice = tensor2.stridedSlice(indicesStart, indicesStop, strides);
+ for (int j = 0; j < 2; ++j) {
+ for (int k = 0; k < 3; ++k) {
+ VERIFY_IS_EQUAL(slice(j,k), tensor2(5-2*j,7-k));
+ }
+ }
+ }
+
+ if(true) {
+ Tensor2f slice(0,1);
+ Index2 strides(1,1);
+ Index2 indicesStart(5,4);
+ Index2 indicesStop(5,5);
+ slice = tensor2.stridedSlice(indicesStart, indicesStop, strides);
+ }
+
+ if(true) { // test clamped degenerate interavls
+ Tensor2f slice(7,11);
+ Index2 strides(1,-1);
+ Index2 indicesStart(-3,20); // should become 0,10
+ Index2 indicesStop(20,-11); // should become 11, -1
+ slice = tensor2.stridedSlice(indicesStart, indicesStop, strides);
+ for (int j = 0; j < 7; ++j) {
+ for (int k = 0; k < 11; ++k) {
+ VERIFY_IS_EQUAL(slice(j,k), tensor2(j,10-k));
+ }
+ }
+ }
+
+ if(true) {
+ Tensor5f slice1(1,1,1,1,1);
+ Eigen::DSizes<Eigen::DenseIndex, 5> indicesStart(1, 2, 3, 4, 5);
+ Eigen::DSizes<Eigen::DenseIndex, 5> indicesStop(2, 3, 4, 5, 6);
+ Eigen::DSizes<Eigen::DenseIndex, 5> strides(1, 1, 1, 1, 1);
+ slice1 = tensor.stridedSlice(indicesStart, indicesStop, strides);
+ VERIFY_IS_EQUAL(slice1(0,0,0,0,0), tensor(1,2,3,4,5));
+ }
+
+ if(true) {
+ Tensor5f slice(1,1,2,2,3);
+ Index5 start(1, 1, 3, 4, 5);
+ Index5 stop(2, 2, 5, 6, 8);
+ Index5 strides(1, 1, 1, 1, 1);
+ slice = tensor.stridedSlice(start, stop, strides);
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 2; ++j) {
+ for (int k = 0; k < 3; ++k) {
+ VERIFY_IS_EQUAL(slice(0,0,i,j,k), tensor(1,1,3+i,4+j,5+k));
+ }
+ }
+ }
+ }
+
+ if(true) {
+ Tensor5f slice(1,1,2,2,3);
+ Index5 strides3(1, 1, -2, 1, -1);
+ Index5 indices3Start(1, 1, 4, 4, 7);
+ Index5 indices3Stop(2, 2, 0, 6, 4);
+ slice = tensor.stridedSlice(indices3Start, indices3Stop, strides3);
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 2; ++j) {
+ for (int k = 0; k < 3; ++k) {
+ VERIFY_IS_EQUAL(slice(0,0,i,j,k), tensor(1,1,4-2*i,4+j,7-k));
+ }
+ }
+ }
+ }
+
+ if(false) { // tests degenerate interval
+ Tensor5f slice(1,1,2,2,3);
+ Index5 strides3(1, 1, 2, 1, 1);
+ Index5 indices3Start(1, 1, 4, 4, 7);
+ Index5 indices3Stop(2, 2, 0, 6, 4);
+ slice = tensor.stridedSlice(indices3Start, indices3Stop, strides3);
+ }
+}
+
+template<int DataLayout>
+static void test_strided_slice_write()
+{
+ typedef Tensor<float, 2, DataLayout> Tensor2f;
+ typedef Eigen::DSizes<Eigen::DenseIndex, 2> Index2;
+
+ Tensor<float, 2, DataLayout> tensor(7,11),tensor2(7,11);
+ tensor.setRandom();
+ tensor2=tensor;
+ Tensor2f slice(2,3);
+
+ slice.setRandom();
+
+ Index2 strides(1,1);
+ Index2 indicesStart(3,4);
+ Index2 indicesStop(5,7);
+ Index2 lengths(2,3);
+
+ tensor.slice(indicesStart,lengths)=slice;
+ tensor2.stridedSlice(indicesStart,indicesStop,strides)=slice;
+
+ for(int i=0;i<7;i++) for(int j=0;j<11;j++){
+ VERIFY_IS_EQUAL(tensor(i,j), tensor2(i,j));
+ }
+}
+
+
template<int DataLayout>
static void test_composition()
{
@@ -337,20 +461,25 @@ static void test_composition()
void test_cxx11_tensor_morphing()
{
- CALL_SUBTEST(test_simple_reshape());
- CALL_SUBTEST(test_reshape_in_expr());
- CALL_SUBTEST(test_reshape_as_lvalue());
-
- CALL_SUBTEST(test_simple_slice<ColMajor>());
- CALL_SUBTEST(test_simple_slice<RowMajor>());
- CALL_SUBTEST(test_const_slice());
- CALL_SUBTEST(test_slice_in_expr<ColMajor>());
- CALL_SUBTEST(test_slice_in_expr<RowMajor>());
- CALL_SUBTEST(test_slice_as_lvalue<ColMajor>());
- CALL_SUBTEST(test_slice_as_lvalue<RowMajor>());
- CALL_SUBTEST(test_slice_raw_data<ColMajor>());
- CALL_SUBTEST(test_slice_raw_data<RowMajor>());
-
- CALL_SUBTEST(test_composition<ColMajor>());
- CALL_SUBTEST(test_composition<RowMajor>());
+ CALL_SUBTEST_1(test_simple_reshape<void>());
+ CALL_SUBTEST_1(test_reshape_in_expr<void>());
+ CALL_SUBTEST_1(test_reshape_as_lvalue<void>());
+
+ CALL_SUBTEST_1(test_simple_slice<ColMajor>());
+ CALL_SUBTEST_1(test_simple_slice<RowMajor>());
+ CALL_SUBTEST_1(test_const_slice());
+ CALL_SUBTEST_2(test_slice_in_expr<ColMajor>());
+ CALL_SUBTEST_3(test_slice_in_expr<RowMajor>());
+ CALL_SUBTEST_4(test_slice_as_lvalue<ColMajor>());
+ CALL_SUBTEST_4(test_slice_as_lvalue<RowMajor>());
+ CALL_SUBTEST_5(test_slice_raw_data<ColMajor>());
+ CALL_SUBTEST_5(test_slice_raw_data<RowMajor>());
+
+ CALL_SUBTEST_6(test_strided_slice_write<ColMajor>());
+ CALL_SUBTEST_6(test_strided_slice<ColMajor>());
+ CALL_SUBTEST_6(test_strided_slice_write<RowMajor>());
+ CALL_SUBTEST_6(test_strided_slice<RowMajor>());
+
+ CALL_SUBTEST_7(test_composition<ColMajor>());
+ CALL_SUBTEST_7(test_composition<RowMajor>());
}