aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/test/cxx11_tensor_padding.cpp
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-09-05 07:47:43 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-09-05 07:47:43 -0700
commit74db22455ae0172faaae91321da0b303bb82369d (patch)
tree773686acd6d7e8abc5a9dfe17d13ca9138f6108e /unsupported/test/cxx11_tensor_padding.cpp
parent1abe4ed14c0012d85e833c5f507f282cf26edc36 (diff)
Misc fixes.
Diffstat (limited to 'unsupported/test/cxx11_tensor_padding.cpp')
-rw-r--r--unsupported/test/cxx11_tensor_padding.cpp38
1 files changed, 36 insertions, 2 deletions
diff --git a/unsupported/test/cxx11_tensor_padding.cpp b/unsupported/test/cxx11_tensor_padding.cpp
index cb010f512..6f74216dd 100644
--- a/unsupported/test/cxx11_tensor_padding.cpp
+++ b/unsupported/test/cxx11_tensor_padding.cpp
@@ -37,9 +37,42 @@ static void test_simple_padding()
for (int k = 0; k < 12; ++k) {
for (int l = 0; l < 7; ++l) {
if (j >= 2 && j < 5 && k >= 3 && k < 8) {
- VERIFY_IS_EQUAL(tensor(i,j-2,k-3,l), padded(i,j,k,l));
+ VERIFY_IS_EQUAL(padded(i,j,k,l), tensor(i,j-2,k-3,l));
} else {
- VERIFY_IS_EQUAL(0.0f, padded(i,j,k,l));
+ VERIFY_IS_EQUAL(padded(i,j,k,l), 0.0f);
+ }
+ }
+ }
+ }
+ }
+}
+
+static void test_padded_expr()
+{
+ Tensor<float, 4> tensor(2,3,5,7);
+ tensor.setRandom();
+
+ array<std::pair<ptrdiff_t, ptrdiff_t>, 4> paddings;
+ paddings[0] = std::make_pair(0, 0);
+ paddings[1] = std::make_pair(2, 1);
+ paddings[2] = std::make_pair(3, 4);
+ paddings[3] = std::make_pair(0, 0);
+
+ Eigen::DSizes<ptrdiff_t, 2> reshape_dims;
+ reshape_dims[0] = 12;
+ reshape_dims[1] = 84;
+
+ Tensor<float, 2> result;
+ result = tensor.pad(paddings).reshape(reshape_dims);
+
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 6; ++j) {
+ for (int k = 0; k < 12; ++k) {
+ for (int l = 0; l < 7; ++l) {
+ if (j >= 2 && j < 5 && k >= 3 && k < 8) {
+ VERIFY_IS_EQUAL(result(i+2*j,k+12*l), tensor(i,j-2,k-3,l));
+ } else {
+ VERIFY_IS_EQUAL(result(i+2*j,k+12*l), 0.0f);
}
}
}
@@ -51,4 +84,5 @@ static void test_simple_padding()
void test_cxx11_tensor_padding()
{
CALL_SUBTEST(test_simple_padding());
+ CALL_SUBTEST(test_padded_expr());
}