diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2018-01-30 20:22:12 +0000 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2018-01-30 20:22:12 +0000 |
commit | 8f55956a570c79a0a8b76bf7134d5150727ca8f1 (patch) | |
tree | 0481f4fa5582a2f5f72d521130e08fd514d5d6ba /unsupported | |
parent | 09a16ba42fa1acc7bb0ace489ee51b3eb958ffa0 (diff) | |
parent | 3122477c8660f4e66e9cf4bf24e4fdfd6d56378c (diff) |
Update the padding computation for PADDING_SAME to be consistent with TensorFlow.
Diffstat (limited to 'unsupported')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h | 4 | ||||
-rw-r--r-- | unsupported/test/cxx11_tensor_image_patch.cpp | 52 |
2 files changed, 56 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h b/unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h index 3c6a2e091..91d4ead28 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h @@ -265,6 +265,10 @@ struct TensorEvaluator<const TensorImagePatchOp<Rows, Cols, ArgType>, Device> // Calculate the padding m_rowPaddingTop = ((m_outputRows - 1) * m_row_strides + m_patch_rows_eff - m_input_rows_eff) / 2; m_colPaddingLeft = ((m_outputCols - 1) * m_col_strides + m_patch_cols_eff - m_input_cols_eff) / 2; + // The padding size calculation for PADDING_SAME has been updated to + // be consistent with how TensorFlow extracts its paddings. + m_rowPaddingTop = numext::maxi<Index>(0, m_rowPaddingTop); + m_colPaddingLeft = numext::maxi<Index>(0, m_colPaddingLeft); break; default: eigen_assert(false && "unexpected padding"); diff --git a/unsupported/test/cxx11_tensor_image_patch.cpp b/unsupported/test/cxx11_tensor_image_patch.cpp index 475c59651..105d32fb4 100644 --- a/unsupported/test/cxx11_tensor_image_patch.cpp +++ b/unsupported/test/cxx11_tensor_image_patch.cpp @@ -405,6 +405,57 @@ void test_patch_padding_same() } } +// Verifies that SAME padding, when computed as negative values, will be clipped +// to zero. +void test_patch_padding_same_negative_padding_clip_to_zero() { + int input_depth = 1; + int input_rows = 15; + int input_cols = 1; + int input_batches = 1; + int ksize = 1; // Corresponds to the Rows and Cols for + // tensor.extract_image_patches<>. + int row_stride = 5; + int col_stride = 1; + // ColMajor + Tensor<float, 4> tensor(input_depth, input_rows, input_cols, input_batches); + // Initializes tensor with incrementing numbers. + for (int i = 0; i < tensor.size(); ++i) { + tensor.data()[i] = i + 1; + } + Tensor<float, 5> result = tensor.extract_image_patches( + ksize, ksize, row_stride, col_stride, 1, 1, PADDING_SAME); + // row padding will be computed as -2 originally and then be clipped to 0. + VERIFY_IS_EQUAL(result.coeff(0), 1.0f); + VERIFY_IS_EQUAL(result.coeff(1), 6.0f); + VERIFY_IS_EQUAL(result.coeff(2), 11.0f); + + VERIFY_IS_EQUAL(result.dimension(0), input_depth); // depth + VERIFY_IS_EQUAL(result.dimension(1), ksize); // kernel rows + VERIFY_IS_EQUAL(result.dimension(2), ksize); // kernel cols + VERIFY_IS_EQUAL(result.dimension(3), 3); // number of patches + VERIFY_IS_EQUAL(result.dimension(4), input_batches); // number of batches + + // RowMajor + Tensor<float, 4, RowMajor> tensor_row_major = tensor.swap_layout(); + VERIFY_IS_EQUAL(tensor.dimension(0), tensor_row_major.dimension(3)); + VERIFY_IS_EQUAL(tensor.dimension(1), tensor_row_major.dimension(2)); + VERIFY_IS_EQUAL(tensor.dimension(2), tensor_row_major.dimension(1)); + VERIFY_IS_EQUAL(tensor.dimension(3), tensor_row_major.dimension(0)); + + Tensor<float, 5, RowMajor> result_row_major = + tensor_row_major.extract_image_patches(ksize, ksize, row_stride, + col_stride, 1, 1, PADDING_SAME); + VERIFY_IS_EQUAL(result_row_major.coeff(0), 1.0f); + VERIFY_IS_EQUAL(result_row_major.coeff(1), 6.0f); + VERIFY_IS_EQUAL(result_row_major.coeff(2), 11.0f); + + VERIFY_IS_EQUAL(result.dimension(0), result_row_major.dimension(4)); + VERIFY_IS_EQUAL(result.dimension(1), result_row_major.dimension(3)); + VERIFY_IS_EQUAL(result.dimension(2), result_row_major.dimension(2)); + VERIFY_IS_EQUAL(result.dimension(3), result_row_major.dimension(1)); + VERIFY_IS_EQUAL(result.dimension(4), result_row_major.dimension(0)); +} + void test_patch_no_extra_dim() { Tensor<float, 3> tensor(2,3,5); @@ -754,4 +805,5 @@ void test_cxx11_tensor_image_patch() CALL_SUBTEST_4(test_patch_padding_valid_same_value()); CALL_SUBTEST_5(test_patch_padding_same()); CALL_SUBTEST_6(test_imagenet_patches()); + CALL_SUBTEST_7(test_patch_padding_same_negative_padding_clip_to_zero()); } |