aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/test/cxx11_tensor_image_patch.cpp
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-05-04 21:08:22 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-05-04 21:08:22 -0700
commit62b710072e282ad70bbcb38468367f7f99232d32 (patch)
tree8c305556d3adb4a0a8aebe039b201aae03c52567 /unsupported/test/cxx11_tensor_image_patch.cpp
parentdd2b45feede628b66b2b9cf07d44fedf666da4fa (diff)
Reduced the memory footprint of the cxx11_tensor_image_patch test
Diffstat (limited to 'unsupported/test/cxx11_tensor_image_patch.cpp')
-rw-r--r--unsupported/test/cxx11_tensor_image_patch.cpp49
1 files changed, 12 insertions, 37 deletions
diff --git a/unsupported/test/cxx11_tensor_image_patch.cpp b/unsupported/test/cxx11_tensor_image_patch.cpp
index 5d6a49181..988b01481 100644
--- a/unsupported/test/cxx11_tensor_image_patch.cpp
+++ b/unsupported/test/cxx11_tensor_image_patch.cpp
@@ -568,13 +568,7 @@ static void test_imagenet_patches()
VERIFY_IS_EQUAL(l_out.dimension(4), 16);
// RowMajor
- Tensor<float, 4, RowMajor> l_in_row_major = l_in.swap_layout();
- VERIFY_IS_EQUAL(l_in.dimension(0), l_in_row_major.dimension(3));
- VERIFY_IS_EQUAL(l_in.dimension(1), l_in_row_major.dimension(2));
- VERIFY_IS_EQUAL(l_in.dimension(2), l_in_row_major.dimension(1));
- VERIFY_IS_EQUAL(l_in.dimension(3), l_in_row_major.dimension(0));
-
- Tensor<float, 5, RowMajor> l_out_row_major = l_in_row_major.extract_image_patches(11, 11);
+ Tensor<float, 5, RowMajor> l_out_row_major = l_in.swap_layout().extract_image_patches(11, 11);
VERIFY_IS_EQUAL(l_out_row_major.dimension(0), 16);
VERIFY_IS_EQUAL(l_out_row_major.dimension(1), 128*128);
VERIFY_IS_EQUAL(l_out_row_major.dimension(2), 11);
@@ -589,10 +583,8 @@ static void test_imagenet_patches()
for (int r = 0; r < 11; ++r) {
for (int d = 0; d < 3; ++d) {
float expected = 0.0f;
- float expected_row_major = 0.0f;
if (r-5+i >= 0 && c-5+j >= 0 && r-5+i < 128 && c-5+j < 128) {
expected = l_in(d, r-5+i, c-5+j, b);
- expected_row_major = l_in_row_major(b, c-5+j, r-5+i, d);
}
// ColMajor
if (l_out(d, r, c, patchId, b) != expected) {
@@ -601,15 +593,13 @@ static void test_imagenet_patches()
VERIFY_IS_EQUAL(l_out(d, r, c, patchId, b), expected);
// RowMajor
if (l_out_row_major(b, patchId, c, r, d) !=
- expected_row_major) {
+ expected) {
std::cout << "Mismatch detected at index i=" << i << " j=" << j
<< " r=" << r << " c=" << c << " d=" << d << " b=" << b
<< std::endl;
}
VERIFY_IS_EQUAL(l_out_row_major(b, patchId, c, r, d),
- expected_row_major);
- // Check that ColMajor and RowMajor agree.
- VERIFY_IS_EQUAL(expected, expected_row_major);
+ expected);
}
}
}
@@ -628,8 +618,7 @@ static void test_imagenet_patches()
VERIFY_IS_EQUAL(l_out.dimension(4), 32);
// RowMajor
- l_in_row_major = l_in.swap_layout();
- l_out_row_major = l_in_row_major.extract_image_patches(9, 9);
+ l_out_row_major = l_in.swap_layout().extract_image_patches(9, 9);
VERIFY_IS_EQUAL(l_out_row_major.dimension(0), 32);
VERIFY_IS_EQUAL(l_out_row_major.dimension(1), 64*64);
VERIFY_IS_EQUAL(l_out_row_major.dimension(2), 9);
@@ -644,10 +633,8 @@ static void test_imagenet_patches()
for (int r = 0; r < 9; ++r) {
for (int d = 0; d < 16; ++d) {
float expected = 0.0f;
- float expected_row_major = 0.0f;
if (r-4+i >= 0 && c-4+j >= 0 && r-4+i < 64 && c-4+j < 64) {
expected = l_in(d, r-4+i, c-4+j, b);
- expected_row_major = l_in_row_major(b, c-4+j, r-4+i, d);
}
// ColMajor
if (l_out(d, r, c, patchId, b) != expected) {
@@ -655,12 +642,10 @@ static void test_imagenet_patches()
}
VERIFY_IS_EQUAL(l_out(d, r, c, patchId, b), expected);
// RowMajor
- if (l_out_row_major(b, patchId, c, r, d) != expected_row_major) {
+ if (l_out_row_major(b, patchId, c, r, d) != expected) {
std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
}
- VERIFY_IS_EQUAL(l_out_row_major(b, patchId, c, r, d), expected_row_major);
- // Check that ColMajor and RowMajor agree.
- VERIFY_IS_EQUAL(expected, expected_row_major);
+ VERIFY_IS_EQUAL(l_out_row_major(b, patchId, c, r, d), expected);
}
}
}
@@ -679,8 +664,7 @@ static void test_imagenet_patches()
VERIFY_IS_EQUAL(l_out.dimension(4), 32);
// RowMajor
- l_in_row_major = l_in.swap_layout();
- l_out_row_major = l_in_row_major.extract_image_patches(7, 7);
+ l_out_row_major = l_in.swap_layout().extract_image_patches(7, 7);
VERIFY_IS_EQUAL(l_out_row_major.dimension(0), 32);
VERIFY_IS_EQUAL(l_out_row_major.dimension(1), 16*16);
VERIFY_IS_EQUAL(l_out_row_major.dimension(2), 7);
@@ -695,10 +679,8 @@ static void test_imagenet_patches()
for (int r = 0; r < 7; ++r) {
for (int d = 0; d < 32; ++d) {
float expected = 0.0f;
- float expected_row_major = 0.0f;
if (r-3+i >= 0 && c-3+j >= 0 && r-3+i < 16 && c-3+j < 16) {
expected = l_in(d, r-3+i, c-3+j, b);
- expected_row_major = l_in_row_major(b, c-3+j, r-3+i, d);
}
// ColMajor
if (l_out(d, r, c, patchId, b) != expected) {
@@ -706,12 +688,10 @@ static void test_imagenet_patches()
}
VERIFY_IS_EQUAL(l_out(d, r, c, patchId, b), expected);
// RowMajor
- if (l_out_row_major(b, patchId, c, r, d) != expected_row_major) {
+ if (l_out_row_major(b, patchId, c, r, d) != expected) {
std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
}
- VERIFY_IS_EQUAL(l_out_row_major(b, patchId, c, r, d), expected_row_major);
- // Check that ColMajor and RowMajor agree.
- VERIFY_IS_EQUAL(expected, expected_row_major);
+ VERIFY_IS_EQUAL(l_out_row_major(b, patchId, c, r, d), expected);
}
}
}
@@ -730,8 +710,7 @@ static void test_imagenet_patches()
VERIFY_IS_EQUAL(l_out.dimension(4), 32);
// RowMajor
- l_in_row_major = l_in.swap_layout();
- l_out_row_major = l_in_row_major.extract_image_patches(3, 3);
+ l_out_row_major = l_in.swap_layout().extract_image_patches(3, 3);
VERIFY_IS_EQUAL(l_out_row_major.dimension(0), 32);
VERIFY_IS_EQUAL(l_out_row_major.dimension(1), 13*13);
VERIFY_IS_EQUAL(l_out_row_major.dimension(2), 3);
@@ -746,10 +725,8 @@ static void test_imagenet_patches()
for (int r = 0; r < 3; ++r) {
for (int d = 0; d < 64; ++d) {
float expected = 0.0f;
- float expected_row_major = 0.0f;
if (r-1+i >= 0 && c-1+j >= 0 && r-1+i < 13 && c-1+j < 13) {
expected = l_in(d, r-1+i, c-1+j, b);
- expected_row_major = l_in_row_major(b, c-1+j, r-1+i, d);
}
// ColMajor
if (l_out(d, r, c, patchId, b) != expected) {
@@ -757,12 +734,10 @@ static void test_imagenet_patches()
}
VERIFY_IS_EQUAL(l_out(d, r, c, patchId, b), expected);
// RowMajor
- if (l_out_row_major(b, patchId, c, r, d) != expected_row_major) {
+ if (l_out_row_major(b, patchId, c, r, d) != expected) {
std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
}
- VERIFY_IS_EQUAL(l_out_row_major(b, patchId, c, r, d), expected_row_major);
- // Check that ColMajor and RowMajor agree.
- VERIFY_IS_EQUAL(expected, expected_row_major);
+ VERIFY_IS_EQUAL(l_out_row_major(b, patchId, c, r, d), expected);
}
}
}