aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/test/cxx11_tensor_ref.cpp
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-03-16 13:05:00 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-03-16 13:05:00 -0700
commitf218c0181d44d7dd129a77108ad6ec063cfbd6cc (patch)
tree1e19f1556a842dbdc98e691ef207b55cbf1f5989 /unsupported/test/cxx11_tensor_ref.cpp
parent35c3a8bb84778a81b2f90fdea10eadeae16863aa (diff)
Fixes the Lvalue computation by actually setting the LvalueBit properly when instantiating tensors of const T. Added a test to check the fix.
Diffstat (limited to 'unsupported/test/cxx11_tensor_ref.cpp')
-rw-r--r--unsupported/test/cxx11_tensor_ref.cpp40
1 files changed, 40 insertions, 0 deletions
diff --git a/unsupported/test/cxx11_tensor_ref.cpp b/unsupported/test/cxx11_tensor_ref.cpp
index aa369f278..c7b5ecddb 100644
--- a/unsupported/test/cxx11_tensor_ref.cpp
+++ b/unsupported/test/cxx11_tensor_ref.cpp
@@ -196,6 +196,45 @@ static void test_coeff_ref()
}
+static void test_nested_ops_with_ref()
+{
+ Tensor<float, 4> t(2, 3, 5, 7);
+ t.setRandom();
+ TensorMap<Tensor<const float, 4> > m(t.data(), 2, 3, 5, 7);
+ array<pair<ptrdiff_t, ptrdiff_t>, 4> paddings;
+ paddings[0] = make_pair(0, 0);
+ paddings[1] = make_pair(2, 1);
+ paddings[2] = make_pair(3, 4);
+ paddings[3] = make_pair(0, 0);
+ Eigen::DSizes<Eigen::DenseIndex, 4> shuffle_dims{0, 1, 2, 3};
+ TensorRef<Tensor<const float, 4> > ref(m.pad(paddings));
+ array<pair<ptrdiff_t, ptrdiff_t>, 4> trivial;
+ trivial[0] = make_pair(0, 0);
+ trivial[1] = make_pair(0, 0);
+ trivial[2] = make_pair(0, 0);
+ trivial[3] = make_pair(0, 0);
+ Tensor<float, 4> padded = ref.shuffle(shuffle_dims).pad(trivial);
+ VERIFY_IS_EQUAL(padded.dimension(0), 2+0);
+ VERIFY_IS_EQUAL(padded.dimension(1), 3+3);
+ VERIFY_IS_EQUAL(padded.dimension(2), 5+7);
+ VERIFY_IS_EQUAL(padded.dimension(3), 7+0);
+
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 6; ++j) {
+ for (int k = 0; k < 12; ++k) {
+ for (int l = 0; l < 7; ++l) {
+ if (j >= 2 && j < 5 && k >= 3 && k < 8) {
+ VERIFY_IS_EQUAL(padded(i,j,k,l), t(i,j-2,k-3,l));
+ } else {
+ VERIFY_IS_EQUAL(padded(i,j,k,l), 0.0f);
+ }
+ }
+ }
+ }
+ }
+}
+
+
void test_cxx11_tensor_ref()
{
CALL_SUBTEST(test_simple_lvalue_ref());
@@ -205,4 +244,5 @@ void test_cxx11_tensor_ref()
CALL_SUBTEST(test_ref_of_ref());
CALL_SUBTEST(test_ref_in_expr());
CALL_SUBTEST(test_coeff_ref());
+ CALL_SUBTEST(test_nested_ops_with_ref());
}