From 0d15ad80195ec5cd33f057068e34aa7e1dc2b783 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Thu, 5 Nov 2015 14:22:30 -0800 Subject: Updated the regressions tests that cover full reductions --- unsupported/test/cxx11_tensor_reduction.cpp | 111 +++++++++++++++++++--------- 1 file changed, 76 insertions(+), 35 deletions(-) (limited to 'unsupported/test/cxx11_tensor_reduction.cpp') diff --git a/unsupported/test/cxx11_tensor_reduction.cpp b/unsupported/test/cxx11_tensor_reduction.cpp index e8180c061..0ec316991 100644 --- a/unsupported/test/cxx11_tensor_reduction.cpp +++ b/unsupported/test/cxx11_tensor_reduction.cpp @@ -13,6 +13,45 @@ using Eigen::Tensor; +template +static void test_trivial_reductions() { + { + Tensor tensor; + tensor.setRandom(); + array reduction_axis; + + Tensor result = tensor.sum(reduction_axis); + VERIFY_IS_EQUAL(result(), tensor()); + } + + { + Tensor tensor(7); + tensor.setRandom(); + array reduction_axis; + + Tensor result = tensor.sum(reduction_axis); + VERIFY_IS_EQUAL(result.dimension(0), 7); + for (int i = 0; i < 7; ++i) { + VERIFY_IS_EQUAL(result(i), tensor(i)); + } + } + + { + Tensor tensor(2, 3); + tensor.setRandom(); + array reduction_axis; + + Tensor result = tensor.sum(reduction_axis); + VERIFY_IS_EQUAL(result.dimension(0), 2); + VERIFY_IS_EQUAL(result.dimension(1), 3); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + VERIFY_IS_EQUAL(result(i, j), tensor(i, j)); + } + } + } +} + template static void test_simple_reductions() { Tensor tensor(2, 3, 5, 7); @@ -37,18 +76,18 @@ static void test_simple_reductions() { } { - Tensor sum1 = tensor.sum(); - VERIFY_IS_EQUAL(sum1.dimension(0), 1); + Tensor sum1 = tensor.sum(); + VERIFY_IS_EQUAL(sum1.rank(), 0); array reduction_axis4; reduction_axis4[0] = 0; reduction_axis4[1] = 1; reduction_axis4[2] = 2; reduction_axis4[3] = 3; - Tensor sum2 = tensor.sum(reduction_axis4); - VERIFY_IS_EQUAL(sum2.dimension(0), 1); + Tensor sum2 = tensor.sum(reduction_axis4); + VERIFY_IS_EQUAL(sum2.rank(), 0); - VERIFY_IS_APPROX(sum1(0), sum2(0)); + VERIFY_IS_APPROX(sum1(), sum2()); } reduction_axis2[0] = 0; @@ -69,18 +108,18 @@ static void test_simple_reductions() { } { - Tensor prod1 = tensor.prod(); - VERIFY_IS_EQUAL(prod1.dimension(0), 1); + Tensor prod1 = tensor.prod(); + VERIFY_IS_EQUAL(prod1.rank(), 0); array reduction_axis4; reduction_axis4[0] = 0; reduction_axis4[1] = 1; reduction_axis4[2] = 2; reduction_axis4[3] = 3; - Tensor prod2 = tensor.prod(reduction_axis4); - VERIFY_IS_EQUAL(prod2.dimension(0), 1); + Tensor prod2 = tensor.prod(reduction_axis4); + VERIFY_IS_EQUAL(prod2.rank(), 0); - VERIFY_IS_APPROX(prod1(0), prod2(0)); + VERIFY_IS_APPROX(prod1(), prod2()); } reduction_axis2[0] = 0; @@ -101,18 +140,18 @@ static void test_simple_reductions() { } { - Tensor max1 = tensor.maximum(); - VERIFY_IS_EQUAL(max1.dimension(0), 1); + Tensor max1 = tensor.maximum(); + VERIFY_IS_EQUAL(max1.rank(), 0); array reduction_axis4; reduction_axis4[0] = 0; reduction_axis4[1] = 1; reduction_axis4[2] = 2; reduction_axis4[3] = 3; - Tensor max2 = tensor.maximum(reduction_axis4); - VERIFY_IS_EQUAL(max2.dimension(0), 1); + Tensor max2 = tensor.maximum(reduction_axis4); + VERIFY_IS_EQUAL(max2.rank(), 0); - VERIFY_IS_APPROX(max1(0), max2(0)); + VERIFY_IS_APPROX(max1(), max2()); } reduction_axis2[0] = 0; @@ -133,18 +172,18 @@ static void test_simple_reductions() { } { - Tensor min1 = tensor.minimum(); - VERIFY_IS_EQUAL(min1.dimension(0), 1); + Tensor min1 = tensor.minimum(); + VERIFY_IS_EQUAL(min1.rank(), 0); array reduction_axis4; reduction_axis4[0] = 0; reduction_axis4[1] = 1; reduction_axis4[2] = 2; reduction_axis4[3] = 3; - Tensor min2 = tensor.minimum(reduction_axis4); - VERIFY_IS_EQUAL(min2.dimension(0), 1); + Tensor min2 = tensor.minimum(reduction_axis4); + VERIFY_IS_EQUAL(min2.rank(), 0); - VERIFY_IS_APPROX(min1(0), min2(0)); + VERIFY_IS_APPROX(min1(), min2()); } reduction_axis2[0] = 0; @@ -167,35 +206,35 @@ static void test_simple_reductions() { } { - Tensor mean1 = tensor.mean(); - VERIFY_IS_EQUAL(mean1.dimension(0), 1); + Tensor mean1 = tensor.mean(); + VERIFY_IS_EQUAL(mean1.rank(), 0); array reduction_axis4; reduction_axis4[0] = 0; reduction_axis4[1] = 1; reduction_axis4[2] = 2; reduction_axis4[3] = 3; - Tensor mean2 = tensor.mean(reduction_axis4); - VERIFY_IS_EQUAL(mean2.dimension(0), 1); + Tensor mean2 = tensor.mean(reduction_axis4); + VERIFY_IS_EQUAL(mean2.rank(), 0); - VERIFY_IS_APPROX(mean1(0), mean2(0)); + VERIFY_IS_APPROX(mean1(), mean2()); } { Tensor ints(10); std::iota(ints.data(), ints.data() + ints.dimension(0), 0); - TensorFixedSize > all; + TensorFixedSize > all; all = ints.all(); - VERIFY(!all(0)); + VERIFY(!all()); all = (ints >= ints.constant(0)).all(); - VERIFY(all(0)); + VERIFY(all()); - TensorFixedSize > any; + TensorFixedSize > any; any = (ints > ints.constant(10)).any(); - VERIFY(!any(0)); + VERIFY(!any()); any = (ints < ints.constant(1)).any(); - VERIFY(any(0)); + VERIFY(any()); } } @@ -207,8 +246,8 @@ static void test_full_reductions() { reduction_axis[0] = 0; reduction_axis[1] = 1; - Tensor result = tensor.sum(reduction_axis); - VERIFY_IS_EQUAL(result.dimension(0), 1); + Tensor result = tensor.sum(reduction_axis); + VERIFY_IS_EQUAL(result.rank(), 0); float sum = 0.0f; for (int i = 0; i < 2; ++i) { @@ -219,7 +258,7 @@ static void test_full_reductions() { VERIFY_IS_APPROX(result(0), sum); result = tensor.square().sum(reduction_axis).sqrt(); - VERIFY_IS_EQUAL(result.dimension(0), 1); + VERIFY_IS_EQUAL(result.rank(), 0); sum = 0.0f; for (int i = 0; i < 2; ++i) { @@ -227,7 +266,7 @@ static void test_full_reductions() { sum += tensor(i, j) * tensor(i, j); } } - VERIFY_IS_APPROX(result(0), sqrtf(sum)); + VERIFY_IS_APPROX(result(), sqrtf(sum)); } struct UserReducer { @@ -418,6 +457,8 @@ static void test_reduce_middle_dims() { } void test_cxx11_tensor_reduction() { + CALL_SUBTEST(test_trivial_reductions()); + CALL_SUBTEST(test_trivial_reductions()); CALL_SUBTEST(test_simple_reductions()); CALL_SUBTEST(test_simple_reductions()); CALL_SUBTEST(test_full_reductions()); -- cgit v1.2.3