aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/test/cxx11_tensor_contraction.cpp
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-09-01 11:41:27 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-09-01 11:41:27 -0700
commitc53f783705e05c07d9f1c02ab12fdb5d57f1a7a9 (patch)
treeaae30d99e99aa985a665db973cd29922f15fcec4 /unsupported/test/cxx11_tensor_contraction.cpp
parentef54723dbe80787f743f6bfa4d11c090486ec01a (diff)
Updated the contraction code to support constant inputs.
Diffstat (limited to 'unsupported/test/cxx11_tensor_contraction.cpp')
-rw-r--r--unsupported/test/cxx11_tensor_contraction.cpp23
1 files changed, 23 insertions, 0 deletions
diff --git a/unsupported/test/cxx11_tensor_contraction.cpp b/unsupported/test/cxx11_tensor_contraction.cpp
index 73623b2ed..ace97057f 100644
--- a/unsupported/test/cxx11_tensor_contraction.cpp
+++ b/unsupported/test/cxx11_tensor_contraction.cpp
@@ -489,6 +489,27 @@ static void test_tensor_product()
}
+template<int DataLayout>
+static void test_const_inputs()
+{
+ Tensor<float, 2, DataLayout> in1(2, 3);
+ Tensor<float, 2, DataLayout> in2(3, 2);
+ in1.setRandom();
+ in2.setRandom();
+
+ TensorMap<Tensor<const float, 2, DataLayout> > mat1(in1.data(), 2, 3);
+ TensorMap<Tensor<const float, 2, DataLayout> > mat2(in2.data(), 3, 2);
+ Tensor<float, 2, DataLayout> mat3(2,2);
+
+ Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
+ mat3 = mat1.contract(mat2, dims);
+
+ VERIFY_IS_APPROX(mat3(0,0), mat1(0,0)*mat2(0,0) + mat1(0,1)*mat2(1,0) + mat1(0,2)*mat2(2,0));
+ VERIFY_IS_APPROX(mat3(0,1), mat1(0,0)*mat2(0,1) + mat1(0,1)*mat2(1,1) + mat1(0,2)*mat2(2,1));
+ VERIFY_IS_APPROX(mat3(1,0), mat1(1,0)*mat2(0,0) + mat1(1,1)*mat2(1,0) + mat1(1,2)*mat2(2,0));
+ VERIFY_IS_APPROX(mat3(1,1), mat1(1,0)*mat2(0,1) + mat1(1,1)*mat2(1,1) + mat1(1,2)*mat2(2,1));
+}
+
void test_cxx11_tensor_contraction()
{
CALL_SUBTEST(test_evals<ColMajor>());
@@ -519,4 +540,6 @@ void test_cxx11_tensor_contraction()
CALL_SUBTEST(test_small_blocking_factors<RowMajor>());
CALL_SUBTEST(test_tensor_product<ColMajor>());
CALL_SUBTEST(test_tensor_product<RowMajor>());
+ CALL_SUBTEST(test_const_inputs<ColMajor>());
+ CALL_SUBTEST(test_const_inputs<RowMajor>());
}