// This file is part of Eigen, a lightweight C++ template library // for linear algebra. // // Copyright (C) 2014 Benoit Steiner // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. #include "main.h" #include using Eigen::Tensor; template static void test_dimension_failures() { Tensor left(2, 3, 1); Tensor right(3, 3, 1); left.setRandom(); right.setRandom(); // Okay; other dimensions are equal. Tensor concatenation = left.concatenate(right, 0); // Dimension mismatches. VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 1)); VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 2)); // Axis > NumDims or < 0. VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 3)); VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, -1)); } template static void test_static_dimension_failure() { Tensor left(2, 3); Tensor right(2, 3, 1); #ifdef CXX11_TENSOR_CONCATENATION_STATIC_DIMENSION_FAILURE // Technically compatible, but we static assert that the inputs have same // NumDims. Tensor concatenation = left.concatenate(right, 0); #endif // This can be worked around in this case. Tensor concatenation = left .reshape(Tensor::Dimensions(2, 3, 1)) .concatenate(right, 0); Tensor alternative = left // Clang compiler break with {{{}}} with an ambiguous error on copy constructor // the variadic DSize constructor added for #ifndef EIGEN_EMULATE_CXX11_META_H. // Solution: // either the code should change to // Tensor::Dimensions{{2, 3}} // or Tensor::Dimensions{Tensor::Dimensions{{2, 3}}} .concatenate(right.reshape(Tensor::Dimensions(2, 3)), 0); } template static void test_simple_concatenation() { Tensor left(2, 3, 1); Tensor right(2, 3, 1); left.setRandom(); right.setRandom(); Tensor concatenation = left.concatenate(right, 0); VERIFY_IS_EQUAL(concatenation.dimension(0), 4); VERIFY_IS_EQUAL(concatenation.dimension(1), 3); VERIFY_IS_EQUAL(concatenation.dimension(2), 1); for (int j = 0; j < 3; ++j) { for (int i = 0; i < 2; ++i) { VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0)); } for (int i = 2; i < 4; ++i) { VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i - 2, j, 0)); } } concatenation = left.concatenate(right, 1); VERIFY_IS_EQUAL(concatenation.dimension(0), 2); VERIFY_IS_EQUAL(concatenation.dimension(1), 6); VERIFY_IS_EQUAL(concatenation.dimension(2), 1); for (int i = 0; i < 2; ++i) { for (int j = 0; j < 3; ++j) { VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0)); } for (int j = 3; j < 6; ++j) { VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i, j - 3, 0)); } } concatenation = left.concatenate(right, 2); VERIFY_IS_EQUAL(concatenation.dimension(0), 2); VERIFY_IS_EQUAL(concatenation.dimension(1), 3); VERIFY_IS_EQUAL(concatenation.dimension(2), 2); for (int i = 0; i < 2; ++i) { for (int j = 0; j < 3; ++j) { VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0)); VERIFY_IS_EQUAL(concatenation(i, j, 1), right(i, j, 0)); } } } // TODO(phli): Add test once we have a real vectorized implementation. // static void test_vectorized_concatenation() {} static void test_concatenation_as_lvalue() { Tensor t1(2, 3); Tensor t2(2, 3); t1.setRandom(); t2.setRandom(); Tensor result(4, 3); result.setRandom(); t1.concatenate(t2, 0) = result; for (int i = 0; i < 2; ++i) { for (int j = 0; j < 3; ++j) { VERIFY_IS_EQUAL(t1(i, j), result(i, j)); VERIFY_IS_EQUAL(t2(i, j), result(i+2, j)); } } } EIGEN_DECLARE_TEST(cxx11_tensor_concatenation) { CALL_SUBTEST(test_dimension_failures()); CALL_SUBTEST(test_dimension_failures()); CALL_SUBTEST(test_static_dimension_failure()); CALL_SUBTEST(test_static_dimension_failure()); CALL_SUBTEST(test_simple_concatenation()); CALL_SUBTEST(test_simple_concatenation()); // CALL_SUBTEST(test_vectorized_concatenation()); CALL_SUBTEST(test_concatenation_as_lvalue()); }