diff options
Diffstat (limited to 'unsupported/test')
-rw-r--r-- | unsupported/test/CMakeLists.txt | 1 | ||||
-rw-r--r-- | unsupported/test/cxx11_tensor_trace.cpp | 171 |
2 files changed, 172 insertions, 0 deletions
diff --git a/unsupported/test/CMakeLists.txt b/unsupported/test/CMakeLists.txt index e639e7056..22647cadd 100644 --- a/unsupported/test/CMakeLists.txt +++ b/unsupported/test/CMakeLists.txt @@ -227,6 +227,7 @@ if(EIGEN_TEST_CXX11) ei_add_test(cxx11_tensor_fft) ei_add_test(cxx11_tensor_ifft) ei_add_test(cxx11_tensor_scan) + ei_add_test(cxx11_tensor_trace) endif() diff --git a/unsupported/test/cxx11_tensor_trace.cpp b/unsupported/test/cxx11_tensor_trace.cpp new file mode 100644 index 000000000..340d1211c --- /dev/null +++ b/unsupported/test/cxx11_tensor_trace.cpp @@ -0,0 +1,171 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2017 Gagan Goel <gagan.nith@gmail.com> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "main.h" + +#include <Eigen/CXX11/Tensor> + +using Eigen::Tensor; +using Eigen::array; + +template <int DataLayout> +static void test_0D_trace() { + Tensor<float, 0, DataLayout> tensor; + tensor.setRandom(); + array<ptrdiff_t, 0> dims; + Tensor<float, 0, DataLayout> result = tensor.trace(dims); + VERIFY_IS_EQUAL(result(), tensor()); +} + + +template <int DataLayout> +static void test_all_dimensions_trace() { + Tensor<float, 3, DataLayout> tensor1(5, 5, 5); + tensor1.setRandom(); + Tensor<float, 0, DataLayout> result1 = tensor1.trace(); + VERIFY_IS_EQUAL(result1.rank(), 0); + float sum = 0.0f; + for (int i = 0; i < 5; ++i) { + sum += tensor1(i, i, i); + } + VERIFY_IS_EQUAL(result1(), sum); + + Tensor<float, 5, DataLayout> tensor2(7, 7, 7, 7, 7); + array<ptrdiff_t, 5> dims({{2, 1, 0, 3, 4}}); + Tensor<float, 0, DataLayout> result2 = tensor2.trace(dims); + VERIFY_IS_EQUAL(result2.rank(), 0); + sum = 0.0f; + for (int i = 0; i < 7; ++i) { + sum += tensor2(i, i, i, i, i); + } + VERIFY_IS_EQUAL(result2(), sum); +} + + +template <int DataLayout> +static void test_simple_trace() { + Tensor<float, 3, DataLayout> tensor1(3, 5, 3); + tensor1.setRandom(); + array<ptrdiff_t, 2> dims1({{0, 2}}); + Tensor<float, 1, DataLayout> result1 = tensor1.trace(dims1); + VERIFY_IS_EQUAL(result1.rank(), 1); + VERIFY_IS_EQUAL(result1.dimension(0), 5); + float sum = 0.0f; + for (int i = 0; i < 5; ++i) { + sum = 0.0f; + for (int j = 0; j < 3; ++j) { + sum += tensor1(j, i, j); + } + VERIFY_IS_EQUAL(result1(i), sum); + } + + Tensor<float, 4, DataLayout> tensor2(5, 5, 7, 7); + tensor2.setRandom(); + array<ptrdiff_t, 2> dims2({{2, 3}}); + Tensor<float, 2, DataLayout> result2 = tensor2.trace(dims2); + VERIFY_IS_EQUAL(result2.rank(), 2); + VERIFY_IS_EQUAL(result2.dimension(0), 5); + VERIFY_IS_EQUAL(result2.dimension(1), 5); + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 5; ++j) { + sum = 0.0f; + for (int k = 0; k < 7; ++k) { + sum += tensor2(i, j, k, k); + } + VERIFY_IS_EQUAL(result2(i, j), sum); + } + } + + array<ptrdiff_t, 2> dims3({{1, 0}}); + Tensor<float, 2, DataLayout> result3 = tensor2.trace(dims3); + VERIFY_IS_EQUAL(result3.rank(), 2); + VERIFY_IS_EQUAL(result3.dimension(0), 7); + VERIFY_IS_EQUAL(result3.dimension(1), 7); + for (int i = 0; i < 7; ++i) { + for (int j = 0; j < 7; ++j) { + sum = 0.0f; + for (int k = 0; k < 5; ++k) { + sum += tensor2(k, k, i, j); + } + VERIFY_IS_EQUAL(result3(i, j), sum); + } + } + + Tensor<float, 5, DataLayout> tensor3(3, 7, 3, 7, 3); + tensor3.setRandom(); + array<ptrdiff_t, 3> dims4({{0, 2, 4}}); + Tensor<float, 2, DataLayout> result4 = tensor3.trace(dims4); + VERIFY_IS_EQUAL(result4.rank(), 2); + VERIFY_IS_EQUAL(result4.dimension(0), 7); + VERIFY_IS_EQUAL(result4.dimension(1), 7); + for (int i = 0; i < 7; ++i) { + for (int j = 0; j < 7; ++j) { + sum = 0.0f; + for (int k = 0; k < 3; ++k) { + sum += tensor3(k, i, k, j, k); + } + VERIFY_IS_EQUAL(result4(i, j), sum); + } + } + + Tensor<float, 5, DataLayout> tensor4(3, 7, 4, 7, 5); + tensor4.setRandom(); + array<ptrdiff_t, 2> dims5({{1, 3}}); + Tensor<float, 3, DataLayout> result5 = tensor4.trace(dims5); + VERIFY_IS_EQUAL(result5.rank(), 3); + VERIFY_IS_EQUAL(result5.dimension(0), 3); + VERIFY_IS_EQUAL(result5.dimension(1), 4); + VERIFY_IS_EQUAL(result5.dimension(2), 5); + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 4; ++j) { + for (int k = 0; k < 5; ++k) { + sum = 0.0f; + for (int l = 0; l < 7; ++l) { + sum += tensor4(i, l, j, l, k); + } + VERIFY_IS_EQUAL(result5(i, j, k), sum); + } + } + } +} + + +template<int DataLayout> +static void test_trace_in_expr() { + Tensor<float, 4, DataLayout> tensor(2, 3, 5, 3); + tensor.setRandom(); + array<ptrdiff_t, 2> dims({{1, 3}}); + Tensor<float, 2, DataLayout> result(2, 5); + result = result.constant(1.0f) - tensor.trace(dims); + VERIFY_IS_EQUAL(result.rank(), 2); + VERIFY_IS_EQUAL(result.dimension(0), 2); + VERIFY_IS_EQUAL(result.dimension(1), 5); + float sum = 0.0f; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 5; ++j) { + sum = 0.0f; + for (int k = 0; k < 3; ++k) { + sum += tensor(i, k, j, k); + } + VERIFY_IS_EQUAL(result(i, j), 1.0f - sum); + } + } +} + + +void test_cxx11_tensor_trace() { + CALL_SUBTEST(test_0D_trace<ColMajor>()); + CALL_SUBTEST(test_0D_trace<RowMajor>()); + CALL_SUBTEST(test_all_dimensions_trace<ColMajor>()); + CALL_SUBTEST(test_all_dimensions_trace<RowMajor>()); + CALL_SUBTEST(test_simple_trace<ColMajor>()); + CALL_SUBTEST(test_simple_trace<RowMajor>()); + CALL_SUBTEST(test_trace_in_expr<ColMajor>()); + CALL_SUBTEST(test_trace_in_expr<RowMajor>()); +} |