From 2ed1838aeb6d3c70c35dbd8d545fba1e7e1c68dc Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Fri, 10 Oct 2014 16:11:27 -0700 Subject: Added support for tensor chips --- unsupported/test/cxx11_tensor_chipping.cpp | 244 +++++++++++++++++++++++++++++ 1 file changed, 244 insertions(+) create mode 100644 unsupported/test/cxx11_tensor_chipping.cpp (limited to 'unsupported/test/cxx11_tensor_chipping.cpp') diff --git a/unsupported/test/cxx11_tensor_chipping.cpp b/unsupported/test/cxx11_tensor_chipping.cpp new file mode 100644 index 000000000..8c8a0cec2 --- /dev/null +++ b/unsupported/test/cxx11_tensor_chipping.cpp @@ -0,0 +1,244 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "main.h" + +#include + +using Eigen::Tensor; + + +static void test_simple_chip() +{ + Tensor tensor(2,3,5,7,11); + tensor.setRandom(); + + Tensor chip1; + chip1 = tensor.chip<0>(1); + VERIFY_IS_EQUAL(chip1.dimension(0), 3); + VERIFY_IS_EQUAL(chip1.dimension(1), 5); + VERIFY_IS_EQUAL(chip1.dimension(2), 7); + VERIFY_IS_EQUAL(chip1.dimension(3), 11); + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 5; ++j) { + for (int k = 0; k < 7; ++k) { + for (int l = 0; l < 11; ++l) { + VERIFY_IS_EQUAL(chip1(i,j,k,l), tensor(1,i,j,k,l)); + } + } + } + } + + Tensor chip2 = tensor.chip<1>(1); + VERIFY_IS_EQUAL(chip2.dimension(0), 2); + VERIFY_IS_EQUAL(chip2.dimension(1), 5); + VERIFY_IS_EQUAL(chip2.dimension(2), 7); + VERIFY_IS_EQUAL(chip2.dimension(3), 11); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 7; ++k) { + for (int l = 0; l < 11; ++l) { + VERIFY_IS_EQUAL(chip2(i,j,k,l), tensor(i,1,j,k,l)); + } + } + } + } + + Tensor chip3 = tensor.chip<2>(2); + VERIFY_IS_EQUAL(chip3.dimension(0), 2); + VERIFY_IS_EQUAL(chip3.dimension(1), 3); + VERIFY_IS_EQUAL(chip3.dimension(2), 7); + VERIFY_IS_EQUAL(chip3.dimension(3), 11); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 7; ++k) { + for (int l = 0; l < 11; ++l) { + VERIFY_IS_EQUAL(chip3(i,j,k,l), tensor(i,j,2,k,l)); + } + } + } + } + + Tensor chip4(tensor.chip<3>(5)); + VERIFY_IS_EQUAL(chip4.dimension(0), 2); + VERIFY_IS_EQUAL(chip4.dimension(1), 3); + VERIFY_IS_EQUAL(chip4.dimension(2), 5); + VERIFY_IS_EQUAL(chip4.dimension(3), 11); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 5; ++k) { + for (int l = 0; l < 7; ++l) { + VERIFY_IS_EQUAL(chip4(i,j,k,l), tensor(i,j,k,5,l)); + } + } + } + } + + Tensor chip5(tensor.chip<4>(7)); + VERIFY_IS_EQUAL(chip5.dimension(0), 2); + VERIFY_IS_EQUAL(chip5.dimension(1), 3); + VERIFY_IS_EQUAL(chip5.dimension(2), 5); + VERIFY_IS_EQUAL(chip5.dimension(3), 7); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 5; ++k) { + for (int l = 0; l < 7; ++l) { + VERIFY_IS_EQUAL(chip5(i,j,k,l), tensor(i,j,k,l,7)); + } + } + } + } +} + + +static void test_chip_in_expr() { + Tensor input1(2,3,5,7,11); + input1.setRandom(); + Tensor input2(3,5,7,11); + input2.setRandom(); + + Tensor result = input1.chip<0>(0) + input2; + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 5; ++j) { + for (int k = 0; k < 7; ++k) { + for (int l = 0; l < 11; ++l) { + float expected = input1(0,i,j,k,l) + input2(i,j,k,l); + VERIFY_IS_EQUAL(result(i,j,k,l), expected); + } + } + } + } + + Tensor input3(3,7,11); + input3.setRandom(); + Tensor result2 = input1.chip<0>(0).chip<1>(2) + input3; + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 7; ++j) { + for (int k = 0; k < 11; ++k) { + float expected = input1(0,i,2,j,k) + input3(i,j,k); + VERIFY_IS_EQUAL(result2(i,j,k), expected); + } + } + } +} + + +static void test_chip_as_lvalue() +{ + Tensor input1(2,3,5,7,11); + input1.setRandom(); + + Tensor input2(3,5,7,11); + input2.setRandom(); + Tensor tensor = input1; + tensor.chip<0>(1) = input2; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 5; ++k) { + for (int l = 0; l < 7; ++l) { + for (int m = 0; m < 11; ++m) { + if (i != 1) { + VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input1(i,j,k,l,m)); + } else { + VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input2(j,k,l,m)); + } + } + } + } + } + } + + Tensor input3(2,5,7,11); + input3.setRandom(); + tensor = input1; + tensor.chip<1>(1) = input3; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 5; ++k) { + for (int l = 0; l < 7; ++l) { + for (int m = 0; m < 11; ++m) { + if (j != 1) { + VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input1(i,j,k,l,m)); + } else { + VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input3(i,k,l,m)); + } + } + } + } + } + } + + Tensor input4(2,3,7,11); + input4.setRandom(); + tensor = input1; + tensor.chip<2>(3) = input4; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 5; ++k) { + for (int l = 0; l < 7; ++l) { + for (int m = 0; m < 11; ++m) { + if (k != 3) { + VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input1(i,j,k,l,m)); + } else { + VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input4(i,j,l,m)); + } + } + } + } + } + } + + Tensor input5(2,3,5,11); + input5.setRandom(); + tensor = input1; + tensor.chip<3>(4) = input5; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 5; ++k) { + for (int l = 0; l < 7; ++l) { + for (int m = 0; m < 11; ++m) { + if (l != 4) { + VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input1(i,j,k,l,m)); + } else { + VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input5(i,j,k,m)); + } + } + } + } + } + } + + Tensor input6(2,3,5,7); + input6.setRandom(); + tensor = input1; + tensor.chip<4>(5) = input6; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 5; ++k) { + for (int l = 0; l < 7; ++l) { + for (int m = 0; m < 11; ++m) { + if (m != 5) { + VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input1(i,j,k,l,m)); + } else { + VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input6(i,j,k,l)); + } + } + } + } + } + } +} + + +void test_cxx11_tensor_chipping() +{ + CALL_SUBTEST(test_simple_chip()); + CALL_SUBTEST(test_chip_in_expr()); + CALL_SUBTEST(test_chip_as_lvalue()); +} -- cgit v1.2.3