diff options
author | Godeffroy Valet <godeffroy.valet@m4x.org> | 2015-07-25 11:58:36 +0200 |
---|---|---|
committer | Godeffroy Valet <godeffroy.valet@m4x.org> | 2015-07-25 11:58:36 +0200 |
commit | 2195822df64c34eacc411043a197ce701ae6b135 (patch) | |
tree | b64e793502bc561b4d443f0352115d3959a6413b /unsupported/test/cxx11_tensor_contraction.cpp | |
parent | 4b3052c54d14f486c2665530f06cdeff636169e4 (diff) |
Allowed tensor contraction operation with an empty array of dimension pairs, which performs a tensor product.
Diffstat (limited to 'unsupported/test/cxx11_tensor_contraction.cpp')
-rw-r--r-- | unsupported/test/cxx11_tensor_contraction.cpp | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/unsupported/test/cxx11_tensor_contraction.cpp b/unsupported/test/cxx11_tensor_contraction.cpp index f4acdc504..b0d52c6cf 100644 --- a/unsupported/test/cxx11_tensor_contraction.cpp +++ b/unsupported/test/cxx11_tensor_contraction.cpp @@ -448,6 +448,31 @@ static void test_small_blocking_factors() } } +template<int DataLayout> +static void test_tensor_product() +{ + Tensor<float, 2, DataLayout> mat1(2, 3); + Tensor<float, 2, DataLayout> mat2(4, 1); + mat1.setRandom(); + mat2.setRandom(); + + Tensor<float, 4, DataLayout> result = mat1.contract(mat2, Eigen::array<DimPair, 0>{{}}); + + VERIFY_IS_EQUAL(result.dimension(0), 2); + VERIFY_IS_EQUAL(result.dimension(1), 3); + VERIFY_IS_EQUAL(result.dimension(2), 4); + VERIFY_IS_EQUAL(result.dimension(3), 1); + for (int i = 0; i < result.dimension(0); ++i) { + for (int j = 0; j < result.dimension(1); ++j) { + for (int k = 0; k < result.dimension(2); ++k) { + for (int l = 0; l < result.dimension(3); ++l) { + VERIFY_IS_APPROX(result(i, j, k, l), mat1(i, j) * mat2(k, l) ); + } + } + } + } +} + void test_cxx11_tensor_contraction() { @@ -477,4 +502,6 @@ void test_cxx11_tensor_contraction() CALL_SUBTEST(test_tensor_vector<RowMajor>()); CALL_SUBTEST(test_small_blocking_factors<ColMajor>()); CALL_SUBTEST(test_small_blocking_factors<RowMajor>()); + CALL_SUBTEST(test_tensor_product<ColMajor>()); + CALL_SUBTEST(test_tensor_product<RowMajor>()); } |