aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/test/cxx11_tensor_contraction.cpp
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-06-05 10:49:34 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-06-05 10:49:34 -0700
commit8998f4099e20ebc80db0aba2582301cd48d31c5a (patch)
tree18fb2111a71d612cf9e31de79e1c4b7250501fdf /unsupported/test/cxx11_tensor_contraction.cpp
parent6fa6cdd2b988da98cbdd2b1a5fd2fd3b9d56a4b1 (diff)
Created additional tests for the tensor code.
Diffstat (limited to 'unsupported/test/cxx11_tensor_contraction.cpp')
-rw-r--r--unsupported/test/cxx11_tensor_contraction.cpp163
1 files changed, 163 insertions, 0 deletions
diff --git a/unsupported/test/cxx11_tensor_contraction.cpp b/unsupported/test/cxx11_tensor_contraction.cpp
new file mode 100644
index 000000000..1c89dfdd1
--- /dev/null
+++ b/unsupported/test/cxx11_tensor_contraction.cpp
@@ -0,0 +1,163 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#include "main.h"
+
+#include <Eigen/CXX11/Tensor>
+
+using Eigen::Tensor;
+
+typedef Tensor<float, 1>::DimensionPair DimPair;
+
+
+static void test_evals()
+{
+ Tensor<float, 2> mat1(2, 3);
+ Tensor<float, 2> mat2(2, 3);
+ Tensor<float, 2> mat3(3, 2);
+
+ mat1.setRandom();
+ mat2.setRandom();
+ mat3.setRandom();
+
+ Tensor<float, 2> mat4(3,3);
+ mat4.setZero();
+ Eigen::array<DimPair, 1> dims3({{DimPair(0, 0)}});
+ TensorEvaluator<decltype(mat1.contract(mat2, dims3))> eval(mat1.contract(mat2, dims3));
+ eval.evalTo(mat4.data());
+ EIGEN_STATIC_ASSERT(TensorEvaluator<decltype(mat1.contract(mat2, dims3))>::NumDims==2ul, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ VERIFY_IS_EQUAL(eval.dimensions()[0], 3);
+ VERIFY_IS_EQUAL(eval.dimensions()[1], 3);
+
+ VERIFY_IS_APPROX(mat4(0,0), mat1(0,0)*mat2(0,0) + mat1(1,0)*mat2(1,0));
+ VERIFY_IS_APPROX(mat4(0,1), mat1(0,0)*mat2(0,1) + mat1(1,0)*mat2(1,1));
+ VERIFY_IS_APPROX(mat4(0,2), mat1(0,0)*mat2(0,2) + mat1(1,0)*mat2(1,2));
+ VERIFY_IS_APPROX(mat4(1,0), mat1(0,1)*mat2(0,0) + mat1(1,1)*mat2(1,0));
+ VERIFY_IS_APPROX(mat4(1,1), mat1(0,1)*mat2(0,1) + mat1(1,1)*mat2(1,1));
+ VERIFY_IS_APPROX(mat4(1,2), mat1(0,1)*mat2(0,2) + mat1(1,1)*mat2(1,2));
+ VERIFY_IS_APPROX(mat4(2,0), mat1(0,2)*mat2(0,0) + mat1(1,2)*mat2(1,0));
+ VERIFY_IS_APPROX(mat4(2,1), mat1(0,2)*mat2(0,1) + mat1(1,2)*mat2(1,1));
+ VERIFY_IS_APPROX(mat4(2,2), mat1(0,2)*mat2(0,2) + mat1(1,2)*mat2(1,2));
+
+ Tensor<float, 2> mat5(2,2);
+ mat5.setZero();
+ Eigen::array<DimPair, 1> dims4({{DimPair(1, 1)}});
+ TensorEvaluator<decltype(mat1.contract(mat2, dims4))> eval2(mat1.contract(mat2, dims4));
+ eval2.evalTo(mat5.data());
+ EIGEN_STATIC_ASSERT(TensorEvaluator<decltype(mat1.contract(mat2, dims4))>::NumDims==2ul, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ VERIFY_IS_EQUAL(eval2.dimensions()[0], 2);
+ VERIFY_IS_EQUAL(eval2.dimensions()[1], 2);
+
+ VERIFY_IS_APPROX(mat5(0,0), mat1(0,0)*mat2(0,0) + mat1(0,1)*mat2(0,1) + mat1(0,2)*mat2(0,2));
+ VERIFY_IS_APPROX(mat5(0,1), mat1(0,0)*mat2(1,0) + mat1(0,1)*mat2(1,1) + mat1(0,2)*mat2(1,2));
+ VERIFY_IS_APPROX(mat5(1,0), mat1(1,0)*mat2(0,0) + mat1(1,1)*mat2(0,1) + mat1(1,2)*mat2(0,2));
+ VERIFY_IS_APPROX(mat5(1,1), mat1(1,0)*mat2(1,0) + mat1(1,1)*mat2(1,1) + mat1(1,2)*mat2(1,2));
+
+ Tensor<float, 2> mat6(2,2);
+ mat6.setZero();
+ Eigen::array<DimPair, 1> dims6({{DimPair(1, 0)}});
+ TensorEvaluator<decltype(mat1.contract(mat3, dims6))> eval3(mat1.contract(mat3, dims6));
+ eval3.evalTo(mat6.data());
+ EIGEN_STATIC_ASSERT(TensorEvaluator<decltype(mat1.contract(mat3, dims6))>::NumDims==2ul, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ VERIFY_IS_EQUAL(eval3.dimensions()[0], 2);
+ VERIFY_IS_EQUAL(eval3.dimensions()[1], 2);
+
+ VERIFY_IS_APPROX(mat6(0,0), mat1(0,0)*mat3(0,0) + mat1(0,1)*mat3(1,0) + mat1(0,2)*mat3(2,0));
+ VERIFY_IS_APPROX(mat6(0,1), mat1(0,0)*mat3(0,1) + mat1(0,1)*mat3(1,1) + mat1(0,2)*mat3(2,1));
+ VERIFY_IS_APPROX(mat6(1,0), mat1(1,0)*mat3(0,0) + mat1(1,1)*mat3(1,0) + mat1(1,2)*mat3(2,0));
+ VERIFY_IS_APPROX(mat6(1,1), mat1(1,0)*mat3(0,1) + mat1(1,1)*mat3(1,1) + mat1(1,2)*mat3(2,1));
+}
+
+
+static void test_scalar()
+{
+ Tensor<float, 1> vec1({6});
+ Tensor<float, 1> vec2({6});
+
+ vec1.setRandom();
+ vec2.setRandom();
+
+ Tensor<float, 1> scalar(1);
+ scalar.setZero();
+ Eigen::array<DimPair, 1> dims({{DimPair(0, 0)}});
+ TensorEvaluator<decltype(vec1.contract(vec2, dims))> eval(vec1.contract(vec2, dims));
+ eval.evalTo(scalar.data());
+ EIGEN_STATIC_ASSERT(TensorEvaluator<decltype(vec1.contract(vec2, dims))>::NumDims==1ul, YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ float expected = 0.0f;
+ for (int i = 0; i < 6; ++i) {
+ expected += vec1(i) * vec2(i);
+ }
+ VERIFY_IS_APPROX(scalar(0), expected);
+}
+
+
+static void test_multidims()
+{
+ Tensor<float, 3> mat1(2, 2, 2);
+ Tensor<float, 4> mat2(2, 2, 2, 2);
+
+ mat1.setRandom();
+ mat2.setRandom();
+
+ Tensor<float, 3> mat3(2, 2, 2);
+ mat3.setZero();
+ Eigen::array<DimPair, 2> dims({{DimPair(1, 2), DimPair(2, 3)}});
+ TensorEvaluator<decltype(mat1.contract(mat2, dims))> eval(mat1.contract(mat2, dims));
+ eval.evalTo(mat3.data());
+ EIGEN_STATIC_ASSERT(TensorEvaluator<decltype(mat1.contract(mat2, dims))>::NumDims==3ul, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ VERIFY_IS_EQUAL(eval.dimensions()[0], 2);
+ VERIFY_IS_EQUAL(eval.dimensions()[1], 2);
+ VERIFY_IS_EQUAL(eval.dimensions()[2], 2);
+
+ VERIFY_IS_APPROX(mat3(0,0,0), mat1(0,0,0)*mat2(0,0,0,0) + mat1(0,1,0)*mat2(0,0,1,0) +
+ mat1(0,0,1)*mat2(0,0,0,1) + mat1(0,1,1)*mat2(0,0,1,1));
+ VERIFY_IS_APPROX(mat3(0,0,1), mat1(0,0,0)*mat2(0,1,0,0) + mat1(0,1,0)*mat2(0,1,1,0) +
+ mat1(0,0,1)*mat2(0,1,0,1) + mat1(0,1,1)*mat2(0,1,1,1));
+ VERIFY_IS_APPROX(mat3(0,1,0), mat1(0,0,0)*mat2(1,0,0,0) + mat1(0,1,0)*mat2(1,0,1,0) +
+ mat1(0,0,1)*mat2(1,0,0,1) + mat1(0,1,1)*mat2(1,0,1,1));
+ VERIFY_IS_APPROX(mat3(0,1,1), mat1(0,0,0)*mat2(1,1,0,0) + mat1(0,1,0)*mat2(1,1,1,0) +
+ mat1(0,0,1)*mat2(1,1,0,1) + mat1(0,1,1)*mat2(1,1,1,1));
+ VERIFY_IS_APPROX(mat3(1,0,0), mat1(1,0,0)*mat2(0,0,0,0) + mat1(1,1,0)*mat2(0,0,1,0) +
+ mat1(1,0,1)*mat2(0,0,0,1) + mat1(1,1,1)*mat2(0,0,1,1));
+ VERIFY_IS_APPROX(mat3(1,0,1), mat1(1,0,0)*mat2(0,1,0,0) + mat1(1,1,0)*mat2(0,1,1,0) +
+ mat1(1,0,1)*mat2(0,1,0,1) + mat1(1,1,1)*mat2(0,1,1,1));
+ VERIFY_IS_APPROX(mat3(1,1,0), mat1(1,0,0)*mat2(1,0,0,0) + mat1(1,1,0)*mat2(1,0,1,0) +
+ mat1(1,0,1)*mat2(1,0,0,1) + mat1(1,1,1)*mat2(1,0,1,1));
+ VERIFY_IS_APPROX(mat3(1,1,1), mat1(1,0,0)*mat2(1,1,0,0) + mat1(1,1,0)*mat2(1,1,1,0) +
+ mat1(1,0,1)*mat2(1,1,0,1) + mat1(1,1,1)*mat2(1,1,1,1));
+}
+
+
+static void test_expr()
+{
+ Tensor<float, 2> mat1(2, 3);
+ Tensor<float, 2> mat2(3, 2);
+ mat1.setRandom();
+ mat2.setRandom();
+
+ Tensor<float, 2> mat3(2,2);
+
+ Eigen::array<DimPair, 1> dims({{DimPair(1, 0)}});
+ mat3 = mat1.contract(mat2, dims);
+
+ VERIFY_IS_APPROX(mat3(0,0), mat1(0,0)*mat2(0,0) + mat1(0,1)*mat2(1,0) + mat1(0,2)*mat2(2,0));
+ VERIFY_IS_APPROX(mat3(0,1), mat1(0,0)*mat2(0,1) + mat1(0,1)*mat2(1,1) + mat1(0,2)*mat2(2,1));
+ VERIFY_IS_APPROX(mat3(1,0), mat1(1,0)*mat2(0,0) + mat1(1,1)*mat2(1,0) + mat1(1,2)*mat2(2,0));
+ VERIFY_IS_APPROX(mat3(1,1), mat1(1,0)*mat2(0,1) + mat1(1,1)*mat2(1,1) + mat1(1,2)*mat2(2,1));
+}
+
+
+void test_cxx11_tensor_contraction()
+{
+ CALL_SUBTEST(test_evals());
+ CALL_SUBTEST(test_scalar());
+ CALL_SUBTEST(test_multidims());
+ CALL_SUBTEST(test_expr());
+}