From 62b710072e282ad70bbcb38468367f7f99232d32 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 4 May 2016 21:08:22 -0700 Subject: Reduced the memory footprint of the cxx11_tensor_image_patch test --- unsupported/test/cxx11_tensor_image_patch.cpp | 49 +++++++-------------------- 1 file changed, 12 insertions(+), 37 deletions(-) (limited to 'unsupported/test/cxx11_tensor_image_patch.cpp') diff --git a/unsupported/test/cxx11_tensor_image_patch.cpp b/unsupported/test/cxx11_tensor_image_patch.cpp index 5d6a49181..988b01481 100644 --- a/unsupported/test/cxx11_tensor_image_patch.cpp +++ b/unsupported/test/cxx11_tensor_image_patch.cpp @@ -568,13 +568,7 @@ static void test_imagenet_patches() VERIFY_IS_EQUAL(l_out.dimension(4), 16); // RowMajor - Tensor l_in_row_major = l_in.swap_layout(); - VERIFY_IS_EQUAL(l_in.dimension(0), l_in_row_major.dimension(3)); - VERIFY_IS_EQUAL(l_in.dimension(1), l_in_row_major.dimension(2)); - VERIFY_IS_EQUAL(l_in.dimension(2), l_in_row_major.dimension(1)); - VERIFY_IS_EQUAL(l_in.dimension(3), l_in_row_major.dimension(0)); - - Tensor l_out_row_major = l_in_row_major.extract_image_patches(11, 11); + Tensor l_out_row_major = l_in.swap_layout().extract_image_patches(11, 11); VERIFY_IS_EQUAL(l_out_row_major.dimension(0), 16); VERIFY_IS_EQUAL(l_out_row_major.dimension(1), 128*128); VERIFY_IS_EQUAL(l_out_row_major.dimension(2), 11); @@ -589,10 +583,8 @@ static void test_imagenet_patches() for (int r = 0; r < 11; ++r) { for (int d = 0; d < 3; ++d) { float expected = 0.0f; - float expected_row_major = 0.0f; if (r-5+i >= 0 && c-5+j >= 0 && r-5+i < 128 && c-5+j < 128) { expected = l_in(d, r-5+i, c-5+j, b); - expected_row_major = l_in_row_major(b, c-5+j, r-5+i, d); } // ColMajor if (l_out(d, r, c, patchId, b) != expected) { @@ -601,15 +593,13 @@ static void test_imagenet_patches() VERIFY_IS_EQUAL(l_out(d, r, c, patchId, b), expected); // RowMajor if (l_out_row_major(b, patchId, c, r, d) != - expected_row_major) { + expected) { std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl; } VERIFY_IS_EQUAL(l_out_row_major(b, patchId, c, r, d), - expected_row_major); - // Check that ColMajor and RowMajor agree. - VERIFY_IS_EQUAL(expected, expected_row_major); + expected); } } } @@ -628,8 +618,7 @@ static void test_imagenet_patches() VERIFY_IS_EQUAL(l_out.dimension(4), 32); // RowMajor - l_in_row_major = l_in.swap_layout(); - l_out_row_major = l_in_row_major.extract_image_patches(9, 9); + l_out_row_major = l_in.swap_layout().extract_image_patches(9, 9); VERIFY_IS_EQUAL(l_out_row_major.dimension(0), 32); VERIFY_IS_EQUAL(l_out_row_major.dimension(1), 64*64); VERIFY_IS_EQUAL(l_out_row_major.dimension(2), 9); @@ -644,10 +633,8 @@ static void test_imagenet_patches() for (int r = 0; r < 9; ++r) { for (int d = 0; d < 16; ++d) { float expected = 0.0f; - float expected_row_major = 0.0f; if (r-4+i >= 0 && c-4+j >= 0 && r-4+i < 64 && c-4+j < 64) { expected = l_in(d, r-4+i, c-4+j, b); - expected_row_major = l_in_row_major(b, c-4+j, r-4+i, d); } // ColMajor if (l_out(d, r, c, patchId, b) != expected) { @@ -655,12 +642,10 @@ static void test_imagenet_patches() } VERIFY_IS_EQUAL(l_out(d, r, c, patchId, b), expected); // RowMajor - if (l_out_row_major(b, patchId, c, r, d) != expected_row_major) { + if (l_out_row_major(b, patchId, c, r, d) != expected) { std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl; } - VERIFY_IS_EQUAL(l_out_row_major(b, patchId, c, r, d), expected_row_major); - // Check that ColMajor and RowMajor agree. - VERIFY_IS_EQUAL(expected, expected_row_major); + VERIFY_IS_EQUAL(l_out_row_major(b, patchId, c, r, d), expected); } } } @@ -679,8 +664,7 @@ static void test_imagenet_patches() VERIFY_IS_EQUAL(l_out.dimension(4), 32); // RowMajor - l_in_row_major = l_in.swap_layout(); - l_out_row_major = l_in_row_major.extract_image_patches(7, 7); + l_out_row_major = l_in.swap_layout().extract_image_patches(7, 7); VERIFY_IS_EQUAL(l_out_row_major.dimension(0), 32); VERIFY_IS_EQUAL(l_out_row_major.dimension(1), 16*16); VERIFY_IS_EQUAL(l_out_row_major.dimension(2), 7); @@ -695,10 +679,8 @@ static void test_imagenet_patches() for (int r = 0; r < 7; ++r) { for (int d = 0; d < 32; ++d) { float expected = 0.0f; - float expected_row_major = 0.0f; if (r-3+i >= 0 && c-3+j >= 0 && r-3+i < 16 && c-3+j < 16) { expected = l_in(d, r-3+i, c-3+j, b); - expected_row_major = l_in_row_major(b, c-3+j, r-3+i, d); } // ColMajor if (l_out(d, r, c, patchId, b) != expected) { @@ -706,12 +688,10 @@ static void test_imagenet_patches() } VERIFY_IS_EQUAL(l_out(d, r, c, patchId, b), expected); // RowMajor - if (l_out_row_major(b, patchId, c, r, d) != expected_row_major) { + if (l_out_row_major(b, patchId, c, r, d) != expected) { std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl; } - VERIFY_IS_EQUAL(l_out_row_major(b, patchId, c, r, d), expected_row_major); - // Check that ColMajor and RowMajor agree. - VERIFY_IS_EQUAL(expected, expected_row_major); + VERIFY_IS_EQUAL(l_out_row_major(b, patchId, c, r, d), expected); } } } @@ -730,8 +710,7 @@ static void test_imagenet_patches() VERIFY_IS_EQUAL(l_out.dimension(4), 32); // RowMajor - l_in_row_major = l_in.swap_layout(); - l_out_row_major = l_in_row_major.extract_image_patches(3, 3); + l_out_row_major = l_in.swap_layout().extract_image_patches(3, 3); VERIFY_IS_EQUAL(l_out_row_major.dimension(0), 32); VERIFY_IS_EQUAL(l_out_row_major.dimension(1), 13*13); VERIFY_IS_EQUAL(l_out_row_major.dimension(2), 3); @@ -746,10 +725,8 @@ static void test_imagenet_patches() for (int r = 0; r < 3; ++r) { for (int d = 0; d < 64; ++d) { float expected = 0.0f; - float expected_row_major = 0.0f; if (r-1+i >= 0 && c-1+j >= 0 && r-1+i < 13 && c-1+j < 13) { expected = l_in(d, r-1+i, c-1+j, b); - expected_row_major = l_in_row_major(b, c-1+j, r-1+i, d); } // ColMajor if (l_out(d, r, c, patchId, b) != expected) { @@ -757,12 +734,10 @@ static void test_imagenet_patches() } VERIFY_IS_EQUAL(l_out(d, r, c, patchId, b), expected); // RowMajor - if (l_out_row_major(b, patchId, c, r, d) != expected_row_major) { + if (l_out_row_major(b, patchId, c, r, d) != expected) { std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl; } - VERIFY_IS_EQUAL(l_out_row_major(b, patchId, c, r, d), expected_row_major); - // Check that ColMajor and RowMajor agree. - VERIFY_IS_EQUAL(expected, expected_row_major); + VERIFY_IS_EQUAL(l_out_row_major(b, patchId, c, r, d), expected); } } } -- cgit v1.2.3