aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/test/cxx11_tensor_contraction.cpp
diff options
context:
space:
mode:
authorGravatar Godeffroy Valet <godeffroy.valet@m4x.org>2015-07-25 11:58:36 +0200
committerGravatar Godeffroy Valet <godeffroy.valet@m4x.org>2015-07-25 11:58:36 +0200
commit2195822df64c34eacc411043a197ce701ae6b135 (patch)
treeb64e793502bc561b4d443f0352115d3959a6413b /unsupported/test/cxx11_tensor_contraction.cpp
parent4b3052c54d14f486c2665530f06cdeff636169e4 (diff)
Allowed tensor contraction operation with an empty array of dimension pairs, which performs a tensor product.
Diffstat (limited to 'unsupported/test/cxx11_tensor_contraction.cpp')
-rw-r--r--unsupported/test/cxx11_tensor_contraction.cpp27
1 files changed, 27 insertions, 0 deletions
diff --git a/unsupported/test/cxx11_tensor_contraction.cpp b/unsupported/test/cxx11_tensor_contraction.cpp
index f4acdc504..b0d52c6cf 100644
--- a/unsupported/test/cxx11_tensor_contraction.cpp
+++ b/unsupported/test/cxx11_tensor_contraction.cpp
@@ -448,6 +448,31 @@ static void test_small_blocking_factors()
}
}
+template<int DataLayout>
+static void test_tensor_product()
+{
+ Tensor<float, 2, DataLayout> mat1(2, 3);
+ Tensor<float, 2, DataLayout> mat2(4, 1);
+ mat1.setRandom();
+ mat2.setRandom();
+
+ Tensor<float, 4, DataLayout> result = mat1.contract(mat2, Eigen::array<DimPair, 0>{{}});
+
+ VERIFY_IS_EQUAL(result.dimension(0), 2);
+ VERIFY_IS_EQUAL(result.dimension(1), 3);
+ VERIFY_IS_EQUAL(result.dimension(2), 4);
+ VERIFY_IS_EQUAL(result.dimension(3), 1);
+ for (int i = 0; i < result.dimension(0); ++i) {
+ for (int j = 0; j < result.dimension(1); ++j) {
+ for (int k = 0; k < result.dimension(2); ++k) {
+ for (int l = 0; l < result.dimension(3); ++l) {
+ VERIFY_IS_APPROX(result(i, j, k, l), mat1(i, j) * mat2(k, l) );
+ }
+ }
+ }
+ }
+}
+
void test_cxx11_tensor_contraction()
{
@@ -477,4 +502,6 @@ void test_cxx11_tensor_contraction()
CALL_SUBTEST(test_tensor_vector<RowMajor>());
CALL_SUBTEST(test_small_blocking_factors<ColMajor>());
CALL_SUBTEST(test_small_blocking_factors<RowMajor>());
+ CALL_SUBTEST(test_tensor_product<ColMajor>());
+ CALL_SUBTEST(test_tensor_product<RowMajor>());
}