From a961d72e65fc537fe571845407b4e2ee0554bd49 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Fri, 6 Jun 2014 16:25:16 -0700 Subject: Added support for convolution and reshaping of tensors. --- unsupported/test/cxx11_tensor_convolution.cpp | 70 +++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 unsupported/test/cxx11_tensor_convolution.cpp (limited to 'unsupported/test/cxx11_tensor_convolution.cpp') diff --git a/unsupported/test/cxx11_tensor_convolution.cpp b/unsupported/test/cxx11_tensor_convolution.cpp new file mode 100644 index 000000000..95e40f64f --- /dev/null +++ b/unsupported/test/cxx11_tensor_convolution.cpp @@ -0,0 +1,70 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "main.h" + +#include + +using Eigen::Tensor; + + +static void test_evals() +{ + Tensor input(3, 3); + Tensor kernel(2); + + input.setRandom(); + kernel.setRandom(); + + Tensor result(2,3); + result.setZero(); + Eigen::array::Index, 1> dims3({0}); + + TensorEvaluator eval(input.convolve(kernel, dims3)); + eval.evalTo(result.data()); + EIGEN_STATIC_ASSERT(TensorEvaluator::NumDims==2ul, YOU_MADE_A_PROGRAMMING_MISTAKE); + VERIFY_IS_EQUAL(eval.dimensions()[0], 2); + VERIFY_IS_EQUAL(eval.dimensions()[1], 3); + + VERIFY_IS_APPROX(result(0,0), input(0,0)*kernel(0) + input(1,0)*kernel(1)); // index 0 + VERIFY_IS_APPROX(result(0,1), input(0,1)*kernel(0) + input(1,1)*kernel(1)); // index 2 + VERIFY_IS_APPROX(result(0,2), input(0,2)*kernel(0) + input(1,2)*kernel(1)); // index 4 + VERIFY_IS_APPROX(result(1,0), input(1,0)*kernel(0) + input(2,0)*kernel(1)); // index 1 + VERIFY_IS_APPROX(result(1,1), input(1,1)*kernel(0) + input(2,1)*kernel(1)); // index 3 + VERIFY_IS_APPROX(result(1,2), input(1,2)*kernel(0) + input(2,2)*kernel(1)); // index 5 +} + + +static void test_expr() +{ + Tensor input(3, 3); + Tensor kernel(2, 2); + input.setRandom(); + kernel.setRandom(); + + Tensor result(2,2); + Eigen::array dims({0, 1}); + result = input.convolve(kernel, dims); + + VERIFY_IS_APPROX(result(0,0), input(0,0)*kernel(0,0) + input(0,1)*kernel(0,1) + + input(1,0)*kernel(1,0) + input(1,1)*kernel(1,1)); + VERIFY_IS_APPROX(result(0,1), input(0,1)*kernel(0,0) + input(0,2)*kernel(0,1) + + input(1,1)*kernel(1,0) + input(1,2)*kernel(1,1)); + VERIFY_IS_APPROX(result(1,0), input(1,0)*kernel(0,0) + input(1,1)*kernel(0,1) + + input(2,0)*kernel(1,0) + input(2,1)*kernel(1,1)); + VERIFY_IS_APPROX(result(1,1), input(1,1)*kernel(0,0) + input(1,2)*kernel(0,1) + + input(2,1)*kernel(1,0) + input(2,2)*kernel(1,1)); +} + + +void test_cxx11_tensor_convolution() +{ + CALL_SUBTEST(test_evals()); + CALL_SUBTEST(test_expr()); +} -- cgit v1.2.3