diff options
Diffstat (limited to 'unsupported/test/cxx11_tensor_contraction.cpp')
-rw-r--r-- | unsupported/test/cxx11_tensor_contraction.cpp | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/unsupported/test/cxx11_tensor_contraction.cpp b/unsupported/test/cxx11_tensor_contraction.cpp index a37fcd967..2b599d30d 100644 --- a/unsupported/test/cxx11_tensor_contraction.cpp +++ b/unsupported/test/cxx11_tensor_contraction.cpp @@ -201,6 +201,37 @@ static void test_full_redux() } +static void test_contraction_of_contraction() +{ + Tensor<float, 2> t1(2, 2); + Tensor<float, 2> t2(2, 2); + Tensor<float, 2> t3(2, 2); + Tensor<float, 2> t4(2, 2); + t1.setRandom(); + t2.setRandom(); + t3.setRandom(); + t4.setRandom(); + + Eigen::array<DimPair, 1> dims({{DimPair(1, 0)}}); + auto contract1 = t1.contract(t2, dims); + auto diff = t3 - contract1; + auto contract2 = t1.contract(t4, dims); + Tensor<float, 2> result = contract2.contract(diff, dims); + VERIFY_IS_EQUAL(result.dimension(0), 2); + VERIFY_IS_EQUAL(result.dimension(1), 2); + + Eigen::Map<MatrixXf> m1(t1.data(), 2, 2); + Eigen::Map<MatrixXf> m2(t2.data(), 2, 2); + Eigen::Map<MatrixXf> m3(t3.data(), 2, 2); + Eigen::Map<MatrixXf> m4(t4.data(), 2, 2); + Eigen::MatrixXf expected = (m1 * m4) * (m3 - m1 * m2); + VERIFY_IS_APPROX(result(0, 0), expected(0, 0)); + VERIFY_IS_APPROX(result(0, 1), expected(0, 1)); + VERIFY_IS_APPROX(result(1, 0), expected(1, 0)); + VERIFY_IS_APPROX(result(1, 1), expected(1, 1)); +} + + static void test_expr() { Tensor<float, 2> mat1(2, 3); @@ -328,6 +359,7 @@ void test_cxx11_tensor_contraction() CALL_SUBTEST(test_multidims()); CALL_SUBTEST(test_holes()); CALL_SUBTEST(test_full_redux()); + CALL_SUBTEST(test_contraction_of_contraction()); CALL_SUBTEST(test_expr()); CALL_SUBTEST(test_out_of_order_contraction()); CALL_SUBTEST(test_consistency()); |