From 3122477c8660f4e66e9cf4bf24e4fdfd6d56378c Mon Sep 17 00:00:00 2001 From: Yangzihao Wang Date: Tue, 12 Dec 2017 11:15:24 -0800 Subject: Update the padding computation for PADDING_SAME to be consistent with TensorFlow. --- unsupported/test/cxx11_tensor_image_patch.cpp | 52 +++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) (limited to 'unsupported/test/cxx11_tensor_image_patch.cpp') diff --git a/unsupported/test/cxx11_tensor_image_patch.cpp b/unsupported/test/cxx11_tensor_image_patch.cpp index 475c59651..105d32fb4 100644 --- a/unsupported/test/cxx11_tensor_image_patch.cpp +++ b/unsupported/test/cxx11_tensor_image_patch.cpp @@ -405,6 +405,57 @@ void test_patch_padding_same() } } +// Verifies that SAME padding, when computed as negative values, will be clipped +// to zero. +void test_patch_padding_same_negative_padding_clip_to_zero() { + int input_depth = 1; + int input_rows = 15; + int input_cols = 1; + int input_batches = 1; + int ksize = 1; // Corresponds to the Rows and Cols for + // tensor.extract_image_patches<>. + int row_stride = 5; + int col_stride = 1; + // ColMajor + Tensor tensor(input_depth, input_rows, input_cols, input_batches); + // Initializes tensor with incrementing numbers. + for (int i = 0; i < tensor.size(); ++i) { + tensor.data()[i] = i + 1; + } + Tensor result = tensor.extract_image_patches( + ksize, ksize, row_stride, col_stride, 1, 1, PADDING_SAME); + // row padding will be computed as -2 originally and then be clipped to 0. + VERIFY_IS_EQUAL(result.coeff(0), 1.0f); + VERIFY_IS_EQUAL(result.coeff(1), 6.0f); + VERIFY_IS_EQUAL(result.coeff(2), 11.0f); + + VERIFY_IS_EQUAL(result.dimension(0), input_depth); // depth + VERIFY_IS_EQUAL(result.dimension(1), ksize); // kernel rows + VERIFY_IS_EQUAL(result.dimension(2), ksize); // kernel cols + VERIFY_IS_EQUAL(result.dimension(3), 3); // number of patches + VERIFY_IS_EQUAL(result.dimension(4), input_batches); // number of batches + + // RowMajor + Tensor tensor_row_major = tensor.swap_layout(); + VERIFY_IS_EQUAL(tensor.dimension(0), tensor_row_major.dimension(3)); + VERIFY_IS_EQUAL(tensor.dimension(1), tensor_row_major.dimension(2)); + VERIFY_IS_EQUAL(tensor.dimension(2), tensor_row_major.dimension(1)); + VERIFY_IS_EQUAL(tensor.dimension(3), tensor_row_major.dimension(0)); + + Tensor result_row_major = + tensor_row_major.extract_image_patches(ksize, ksize, row_stride, + col_stride, 1, 1, PADDING_SAME); + VERIFY_IS_EQUAL(result_row_major.coeff(0), 1.0f); + VERIFY_IS_EQUAL(result_row_major.coeff(1), 6.0f); + VERIFY_IS_EQUAL(result_row_major.coeff(2), 11.0f); + + VERIFY_IS_EQUAL(result.dimension(0), result_row_major.dimension(4)); + VERIFY_IS_EQUAL(result.dimension(1), result_row_major.dimension(3)); + VERIFY_IS_EQUAL(result.dimension(2), result_row_major.dimension(2)); + VERIFY_IS_EQUAL(result.dimension(3), result_row_major.dimension(1)); + VERIFY_IS_EQUAL(result.dimension(4), result_row_major.dimension(0)); +} + void test_patch_no_extra_dim() { Tensor tensor(2,3,5); @@ -754,4 +805,5 @@ void test_cxx11_tensor_image_patch() CALL_SUBTEST_4(test_patch_padding_valid_same_value()); CALL_SUBTEST_5(test_patch_padding_same()); CALL_SUBTEST_6(test_imagenet_patches()); + CALL_SUBTEST_7(test_patch_padding_same_negative_padding_clip_to_zero()); } -- cgit v1.2.3