aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/test/cxx11_tensor_contraction.cpp
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-10-03 19:33:44 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-10-03 19:33:44 -0700
commit152f3218ac9b6941cf6dbc960c2d4a6d1099eb06 (patch)
tree1e6f686a98ebc338485a3fea60bf24df614acdc0 /unsupported/test/cxx11_tensor_contraction.cpp
parentaf2e5995e2ba48384024bbc8432bd6dbbebf71d2 (diff)
Improved contraction test
Diffstat (limited to 'unsupported/test/cxx11_tensor_contraction.cpp')
-rw-r--r--unsupported/test/cxx11_tensor_contraction.cpp32
1 files changed, 32 insertions, 0 deletions
diff --git a/unsupported/test/cxx11_tensor_contraction.cpp b/unsupported/test/cxx11_tensor_contraction.cpp
index a37fcd967..2b599d30d 100644
--- a/unsupported/test/cxx11_tensor_contraction.cpp
+++ b/unsupported/test/cxx11_tensor_contraction.cpp
@@ -201,6 +201,37 @@ static void test_full_redux()
}
+static void test_contraction_of_contraction()
+{
+ Tensor<float, 2> t1(2, 2);
+ Tensor<float, 2> t2(2, 2);
+ Tensor<float, 2> t3(2, 2);
+ Tensor<float, 2> t4(2, 2);
+ t1.setRandom();
+ t2.setRandom();
+ t3.setRandom();
+ t4.setRandom();
+
+ Eigen::array<DimPair, 1> dims({{DimPair(1, 0)}});
+ auto contract1 = t1.contract(t2, dims);
+ auto diff = t3 - contract1;
+ auto contract2 = t1.contract(t4, dims);
+ Tensor<float, 2> result = contract2.contract(diff, dims);
+ VERIFY_IS_EQUAL(result.dimension(0), 2);
+ VERIFY_IS_EQUAL(result.dimension(1), 2);
+
+ Eigen::Map<MatrixXf> m1(t1.data(), 2, 2);
+ Eigen::Map<MatrixXf> m2(t2.data(), 2, 2);
+ Eigen::Map<MatrixXf> m3(t3.data(), 2, 2);
+ Eigen::Map<MatrixXf> m4(t4.data(), 2, 2);
+ Eigen::MatrixXf expected = (m1 * m4) * (m3 - m1 * m2);
+ VERIFY_IS_APPROX(result(0, 0), expected(0, 0));
+ VERIFY_IS_APPROX(result(0, 1), expected(0, 1));
+ VERIFY_IS_APPROX(result(1, 0), expected(1, 0));
+ VERIFY_IS_APPROX(result(1, 1), expected(1, 1));
+}
+
+
static void test_expr()
{
Tensor<float, 2> mat1(2, 3);
@@ -328,6 +359,7 @@ void test_cxx11_tensor_contraction()
CALL_SUBTEST(test_multidims());
CALL_SUBTEST(test_holes());
CALL_SUBTEST(test_full_redux());
+ CALL_SUBTEST(test_contraction_of_contraction());
CALL_SUBTEST(test_expr());
CALL_SUBTEST(test_out_of_order_contraction());
CALL_SUBTEST(test_consistency());