diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-01-28 10:02:47 -0800 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-01-28 10:02:47 -0800 |
commit | 5a6ea4edf61b5626a781070c6342fc16606b490a (patch) | |
tree | 2e94aad11b5ca76e48e17bce25979694441879bc /unsupported/test/cxx11_tensor_reduction.cpp | |
parent | 9dfdbd7e568bd3aa9a4610986dcfc679b9ea425d (diff) |
Added more tests to cover tensor reductions
Diffstat (limited to 'unsupported/test/cxx11_tensor_reduction.cpp')
-rw-r--r-- | unsupported/test/cxx11_tensor_reduction.cpp | 37 |
1 files changed, 35 insertions, 2 deletions
diff --git a/unsupported/test/cxx11_tensor_reduction.cpp b/unsupported/test/cxx11_tensor_reduction.cpp index 99e19eba4..5c3184833 100644 --- a/unsupported/test/cxx11_tensor_reduction.cpp +++ b/unsupported/test/cxx11_tensor_reduction.cpp @@ -369,6 +369,37 @@ static void test_innermost_first_dims() { } } +template <int DataLayout> +static void test_reduce_middle_dims() { + Tensor<float, 4, DataLayout> in(72, 53, 97, 113); + Tensor<float, 2, DataLayout> out(72, 53); + in.setRandom(); + +// Reduce on the innermost dimensions. +#if __cplusplus <= 199711L + array<int, 2> reduction_axis; + reduction_axis[0] = 1; + reduction_axis[1] = 2; +#else + // This triggers the use of packets for RowMajor. + Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<2>> reduction_axis; +#endif + + out = in.maximum(reduction_axis); + + for (int i = 0; i < 72; ++i) { + for (int j = 0; j < 113; ++j) { + float expected = -1e10f; + for (int k = 0; k < 53; ++k) { + for (int l = 0; l < 97; ++l) { + expected = (std::max)(expected, in(i, k, l, j)); + } + } + VERIFY_IS_APPROX(out(i, j), expected); + } + } +} + void test_cxx11_tensor_reduction() { CALL_SUBTEST(test_simple_reductions<ColMajor>()); CALL_SUBTEST(test_simple_reductions<RowMajor>()); @@ -380,8 +411,10 @@ void test_cxx11_tensor_reduction() { CALL_SUBTEST(test_tensor_maps<RowMajor>()); CALL_SUBTEST(test_static_dims<ColMajor>()); CALL_SUBTEST(test_static_dims<RowMajor>()); - CALL_SUBTEST(test_innermost_last_dims<RowMajor>()); CALL_SUBTEST(test_innermost_last_dims<ColMajor>()); - CALL_SUBTEST(test_innermost_first_dims<RowMajor>()); + CALL_SUBTEST(test_innermost_last_dims<RowMajor>()); CALL_SUBTEST(test_innermost_first_dims<ColMajor>()); + CALL_SUBTEST(test_innermost_first_dims<RowMajor>()); + CALL_SUBTEST(test_reduce_middle_dims<ColMajor>()); + CALL_SUBTEST(test_reduce_middle_dims<RowMajor>()); } |