aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/test/cxx11_tensor_trace.cpp
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2017-07-07 04:18:03 +0000
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2017-07-07 04:18:03 +0000
commit9daed6795224ef93719db66b71098bb7ac1a30ec (patch)
tree0d1c01b6c368cdf9ae65b496958868f8d5ef0711 /unsupported/test/cxx11_tensor_trace.cpp
parent6795512e5942b5fd1829f776fde6611a7405b5bf (diff)
Merged in tntnatbry/eigen (pull request PR-319)
Tensor Trace op
Diffstat (limited to 'unsupported/test/cxx11_tensor_trace.cpp')
-rw-r--r--unsupported/test/cxx11_tensor_trace.cpp171
1 files changed, 171 insertions, 0 deletions
diff --git a/unsupported/test/cxx11_tensor_trace.cpp b/unsupported/test/cxx11_tensor_trace.cpp
new file mode 100644
index 000000000..340d1211c
--- /dev/null
+++ b/unsupported/test/cxx11_tensor_trace.cpp
@@ -0,0 +1,171 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2017 Gagan Goel <gagan.nith@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#include "main.h"
+
+#include <Eigen/CXX11/Tensor>
+
+using Eigen::Tensor;
+using Eigen::array;
+
+template <int DataLayout>
+static void test_0D_trace() {
+ Tensor<float, 0, DataLayout> tensor;
+ tensor.setRandom();
+ array<ptrdiff_t, 0> dims;
+ Tensor<float, 0, DataLayout> result = tensor.trace(dims);
+ VERIFY_IS_EQUAL(result(), tensor());
+}
+
+
+template <int DataLayout>
+static void test_all_dimensions_trace() {
+ Tensor<float, 3, DataLayout> tensor1(5, 5, 5);
+ tensor1.setRandom();
+ Tensor<float, 0, DataLayout> result1 = tensor1.trace();
+ VERIFY_IS_EQUAL(result1.rank(), 0);
+ float sum = 0.0f;
+ for (int i = 0; i < 5; ++i) {
+ sum += tensor1(i, i, i);
+ }
+ VERIFY_IS_EQUAL(result1(), sum);
+
+ Tensor<float, 5, DataLayout> tensor2(7, 7, 7, 7, 7);
+ array<ptrdiff_t, 5> dims({{2, 1, 0, 3, 4}});
+ Tensor<float, 0, DataLayout> result2 = tensor2.trace(dims);
+ VERIFY_IS_EQUAL(result2.rank(), 0);
+ sum = 0.0f;
+ for (int i = 0; i < 7; ++i) {
+ sum += tensor2(i, i, i, i, i);
+ }
+ VERIFY_IS_EQUAL(result2(), sum);
+}
+
+
+template <int DataLayout>
+static void test_simple_trace() {
+ Tensor<float, 3, DataLayout> tensor1(3, 5, 3);
+ tensor1.setRandom();
+ array<ptrdiff_t, 2> dims1({{0, 2}});
+ Tensor<float, 1, DataLayout> result1 = tensor1.trace(dims1);
+ VERIFY_IS_EQUAL(result1.rank(), 1);
+ VERIFY_IS_EQUAL(result1.dimension(0), 5);
+ float sum = 0.0f;
+ for (int i = 0; i < 5; ++i) {
+ sum = 0.0f;
+ for (int j = 0; j < 3; ++j) {
+ sum += tensor1(j, i, j);
+ }
+ VERIFY_IS_EQUAL(result1(i), sum);
+ }
+
+ Tensor<float, 4, DataLayout> tensor2(5, 5, 7, 7);
+ tensor2.setRandom();
+ array<ptrdiff_t, 2> dims2({{2, 3}});
+ Tensor<float, 2, DataLayout> result2 = tensor2.trace(dims2);
+ VERIFY_IS_EQUAL(result2.rank(), 2);
+ VERIFY_IS_EQUAL(result2.dimension(0), 5);
+ VERIFY_IS_EQUAL(result2.dimension(1), 5);
+ for (int i = 0; i < 5; ++i) {
+ for (int j = 0; j < 5; ++j) {
+ sum = 0.0f;
+ for (int k = 0; k < 7; ++k) {
+ sum += tensor2(i, j, k, k);
+ }
+ VERIFY_IS_EQUAL(result2(i, j), sum);
+ }
+ }
+
+ array<ptrdiff_t, 2> dims3({{1, 0}});
+ Tensor<float, 2, DataLayout> result3 = tensor2.trace(dims3);
+ VERIFY_IS_EQUAL(result3.rank(), 2);
+ VERIFY_IS_EQUAL(result3.dimension(0), 7);
+ VERIFY_IS_EQUAL(result3.dimension(1), 7);
+ for (int i = 0; i < 7; ++i) {
+ for (int j = 0; j < 7; ++j) {
+ sum = 0.0f;
+ for (int k = 0; k < 5; ++k) {
+ sum += tensor2(k, k, i, j);
+ }
+ VERIFY_IS_EQUAL(result3(i, j), sum);
+ }
+ }
+
+ Tensor<float, 5, DataLayout> tensor3(3, 7, 3, 7, 3);
+ tensor3.setRandom();
+ array<ptrdiff_t, 3> dims4({{0, 2, 4}});
+ Tensor<float, 2, DataLayout> result4 = tensor3.trace(dims4);
+ VERIFY_IS_EQUAL(result4.rank(), 2);
+ VERIFY_IS_EQUAL(result4.dimension(0), 7);
+ VERIFY_IS_EQUAL(result4.dimension(1), 7);
+ for (int i = 0; i < 7; ++i) {
+ for (int j = 0; j < 7; ++j) {
+ sum = 0.0f;
+ for (int k = 0; k < 3; ++k) {
+ sum += tensor3(k, i, k, j, k);
+ }
+ VERIFY_IS_EQUAL(result4(i, j), sum);
+ }
+ }
+
+ Tensor<float, 5, DataLayout> tensor4(3, 7, 4, 7, 5);
+ tensor4.setRandom();
+ array<ptrdiff_t, 2> dims5({{1, 3}});
+ Tensor<float, 3, DataLayout> result5 = tensor4.trace(dims5);
+ VERIFY_IS_EQUAL(result5.rank(), 3);
+ VERIFY_IS_EQUAL(result5.dimension(0), 3);
+ VERIFY_IS_EQUAL(result5.dimension(1), 4);
+ VERIFY_IS_EQUAL(result5.dimension(2), 5);
+ for (int i = 0; i < 3; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ for (int k = 0; k < 5; ++k) {
+ sum = 0.0f;
+ for (int l = 0; l < 7; ++l) {
+ sum += tensor4(i, l, j, l, k);
+ }
+ VERIFY_IS_EQUAL(result5(i, j, k), sum);
+ }
+ }
+ }
+}
+
+
+template<int DataLayout>
+static void test_trace_in_expr() {
+ Tensor<float, 4, DataLayout> tensor(2, 3, 5, 3);
+ tensor.setRandom();
+ array<ptrdiff_t, 2> dims({{1, 3}});
+ Tensor<float, 2, DataLayout> result(2, 5);
+ result = result.constant(1.0f) - tensor.trace(dims);
+ VERIFY_IS_EQUAL(result.rank(), 2);
+ VERIFY_IS_EQUAL(result.dimension(0), 2);
+ VERIFY_IS_EQUAL(result.dimension(1), 5);
+ float sum = 0.0f;
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 5; ++j) {
+ sum = 0.0f;
+ for (int k = 0; k < 3; ++k) {
+ sum += tensor(i, k, j, k);
+ }
+ VERIFY_IS_EQUAL(result(i, j), 1.0f - sum);
+ }
+ }
+}
+
+
+void test_cxx11_tensor_trace() {
+ CALL_SUBTEST(test_0D_trace<ColMajor>());
+ CALL_SUBTEST(test_0D_trace<RowMajor>());
+ CALL_SUBTEST(test_all_dimensions_trace<ColMajor>());
+ CALL_SUBTEST(test_all_dimensions_trace<RowMajor>());
+ CALL_SUBTEST(test_simple_trace<ColMajor>());
+ CALL_SUBTEST(test_simple_trace<RowMajor>());
+ CALL_SUBTEST(test_trace_in_expr<ColMajor>());
+ CALL_SUBTEST(test_trace_in_expr<RowMajor>());
+}