From 7caaf6453b7b1f58d953729380d596b2d9b27835 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 1 Oct 2014 20:38:22 -0700 Subject: Added support for tensor reductions and concatenations --- unsupported/test/cxx11_tensor_concatenation.cpp | 110 ++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 unsupported/test/cxx11_tensor_concatenation.cpp (limited to 'unsupported/test/cxx11_tensor_concatenation.cpp') diff --git a/unsupported/test/cxx11_tensor_concatenation.cpp b/unsupported/test/cxx11_tensor_concatenation.cpp new file mode 100644 index 000000000..8fd4f5f80 --- /dev/null +++ b/unsupported/test/cxx11_tensor_concatenation.cpp @@ -0,0 +1,110 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "main.h" + +#include + +using Eigen::Tensor; + +static void test_dimension_failures() +{ + Tensor left(2, 3, 1); + Tensor right(3, 3, 1); + left.setRandom(); + right.setRandom(); + + // Okay; other dimensions are equal. + Tensor concatenation = left.concatenate(right, 0); + + // Dimension mismatches. + VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 1)); + VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 2)); + + // Axis > NumDims or < 0. + VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 3)); + VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, -1)); +} + +static void test_static_dimension_failure() +{ + Tensor left(2, 3); + Tensor right(2, 3, 1); + +#ifdef CXX11_TENSOR_CONCATENATION_STATIC_DIMENSION_FAILURE + // Technically compatible, but we static assert that the inputs have same + // NumDims. + Tensor concatenation = left.concatenate(right, 0); +#endif + + // This can be worked around in this case. + Tensor concatenation = left + .reshape(Tensor::Dimensions{{2, 3, 1}}) + .concatenate(right, 0); + Tensor alternative = left + .concatenate(right.reshape(Tensor::Dimensions{{2, 3}}), 0); +} + +static void test_simple_concatenation() +{ + Tensor left(2, 3, 1); + Tensor right(2, 3, 1); + left.setRandom(); + right.setRandom(); + + Tensor concatenation = left.concatenate(right, 0); + VERIFY_IS_EQUAL(concatenation.dimension(0), 4); + VERIFY_IS_EQUAL(concatenation.dimension(1), 3); + VERIFY_IS_EQUAL(concatenation.dimension(2), 1); + for (int j = 0; j < 3; ++j) { + for (int i = 0; i < 2; ++i) { + VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0)); + } + for (int i = 2; i < 4; ++i) { + VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i - 2, j, 0)); + } + } + + concatenation = left.concatenate(right, 1); + VERIFY_IS_EQUAL(concatenation.dimension(0), 2); + VERIFY_IS_EQUAL(concatenation.dimension(1), 6); + VERIFY_IS_EQUAL(concatenation.dimension(2), 1); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0)); + } + for (int j = 3; j < 6; ++j) { + VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i, j - 3, 0)); + } + } + + concatenation = left.concatenate(right, 2); + VERIFY_IS_EQUAL(concatenation.dimension(0), 2); + VERIFY_IS_EQUAL(concatenation.dimension(1), 3); + VERIFY_IS_EQUAL(concatenation.dimension(2), 2); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0)); + VERIFY_IS_EQUAL(concatenation(i, j, 1), right(i, j, 0)); + } + } +} + + +// TODO(phli): Add test once we have a real vectorized implementation. +// static void test_vectorized_concatenation() {} + + +void test_cxx11_tensor_concatenation() +{ + CALL_SUBTEST(test_dimension_failures()); + CALL_SUBTEST(test_static_dimension_failure()); + CALL_SUBTEST(test_simple_concatenation()); + // CALL_SUBTEST(test_vectorized_concatenation()); +} -- cgit v1.2.3