aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2018-01-30 20:22:12 +0000
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2018-01-30 20:22:12 +0000
commit8f55956a570c79a0a8b76bf7134d5150727ca8f1 (patch)
tree0481f4fa5582a2f5f72d521130e08fd514d5d6ba /unsupported
parent09a16ba42fa1acc7bb0ace489ee51b3eb958ffa0 (diff)
parent3122477c8660f4e66e9cf4bf24e4fdfd6d56378c (diff)
Update the padding computation for PADDING_SAME to be consistent with TensorFlow.
Diffstat (limited to 'unsupported')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h4
-rw-r--r--unsupported/test/cxx11_tensor_image_patch.cpp52
2 files changed, 56 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h b/unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h
index 3c6a2e091..91d4ead28 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h
@@ -265,6 +265,10 @@ struct TensorEvaluator<const TensorImagePatchOp<Rows, Cols, ArgType>, Device>
// Calculate the padding
m_rowPaddingTop = ((m_outputRows - 1) * m_row_strides + m_patch_rows_eff - m_input_rows_eff) / 2;
m_colPaddingLeft = ((m_outputCols - 1) * m_col_strides + m_patch_cols_eff - m_input_cols_eff) / 2;
+ // The padding size calculation for PADDING_SAME has been updated to
+ // be consistent with how TensorFlow extracts its paddings.
+ m_rowPaddingTop = numext::maxi<Index>(0, m_rowPaddingTop);
+ m_colPaddingLeft = numext::maxi<Index>(0, m_colPaddingLeft);
break;
default:
eigen_assert(false && "unexpected padding");
diff --git a/unsupported/test/cxx11_tensor_image_patch.cpp b/unsupported/test/cxx11_tensor_image_patch.cpp
index 475c59651..105d32fb4 100644
--- a/unsupported/test/cxx11_tensor_image_patch.cpp
+++ b/unsupported/test/cxx11_tensor_image_patch.cpp
@@ -405,6 +405,57 @@ void test_patch_padding_same()
}
}
+// Verifies that SAME padding, when computed as negative values, will be clipped
+// to zero.
+void test_patch_padding_same_negative_padding_clip_to_zero() {
+ int input_depth = 1;
+ int input_rows = 15;
+ int input_cols = 1;
+ int input_batches = 1;
+ int ksize = 1; // Corresponds to the Rows and Cols for
+ // tensor.extract_image_patches<>.
+ int row_stride = 5;
+ int col_stride = 1;
+ // ColMajor
+ Tensor<float, 4> tensor(input_depth, input_rows, input_cols, input_batches);
+ // Initializes tensor with incrementing numbers.
+ for (int i = 0; i < tensor.size(); ++i) {
+ tensor.data()[i] = i + 1;
+ }
+ Tensor<float, 5> result = tensor.extract_image_patches(
+ ksize, ksize, row_stride, col_stride, 1, 1, PADDING_SAME);
+ // row padding will be computed as -2 originally and then be clipped to 0.
+ VERIFY_IS_EQUAL(result.coeff(0), 1.0f);
+ VERIFY_IS_EQUAL(result.coeff(1), 6.0f);
+ VERIFY_IS_EQUAL(result.coeff(2), 11.0f);
+
+ VERIFY_IS_EQUAL(result.dimension(0), input_depth); // depth
+ VERIFY_IS_EQUAL(result.dimension(1), ksize); // kernel rows
+ VERIFY_IS_EQUAL(result.dimension(2), ksize); // kernel cols
+ VERIFY_IS_EQUAL(result.dimension(3), 3); // number of patches
+ VERIFY_IS_EQUAL(result.dimension(4), input_batches); // number of batches
+
+ // RowMajor
+ Tensor<float, 4, RowMajor> tensor_row_major = tensor.swap_layout();
+ VERIFY_IS_EQUAL(tensor.dimension(0), tensor_row_major.dimension(3));
+ VERIFY_IS_EQUAL(tensor.dimension(1), tensor_row_major.dimension(2));
+ VERIFY_IS_EQUAL(tensor.dimension(2), tensor_row_major.dimension(1));
+ VERIFY_IS_EQUAL(tensor.dimension(3), tensor_row_major.dimension(0));
+
+ Tensor<float, 5, RowMajor> result_row_major =
+ tensor_row_major.extract_image_patches(ksize, ksize, row_stride,
+ col_stride, 1, 1, PADDING_SAME);
+ VERIFY_IS_EQUAL(result_row_major.coeff(0), 1.0f);
+ VERIFY_IS_EQUAL(result_row_major.coeff(1), 6.0f);
+ VERIFY_IS_EQUAL(result_row_major.coeff(2), 11.0f);
+
+ VERIFY_IS_EQUAL(result.dimension(0), result_row_major.dimension(4));
+ VERIFY_IS_EQUAL(result.dimension(1), result_row_major.dimension(3));
+ VERIFY_IS_EQUAL(result.dimension(2), result_row_major.dimension(2));
+ VERIFY_IS_EQUAL(result.dimension(3), result_row_major.dimension(1));
+ VERIFY_IS_EQUAL(result.dimension(4), result_row_major.dimension(0));
+}
+
void test_patch_no_extra_dim()
{
Tensor<float, 3> tensor(2,3,5);
@@ -754,4 +805,5 @@ void test_cxx11_tensor_image_patch()
CALL_SUBTEST_4(test_patch_padding_valid_same_value());
CALL_SUBTEST_5(test_patch_padding_same());
CALL_SUBTEST_6(test_imagenet_patches());
+ CALL_SUBTEST_7(test_patch_padding_same_negative_padding_clip_to_zero());
}