aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/test/cxx11_tensor_contraction.cpp
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-03-17 15:08:02 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-03-17 15:08:02 -0700
commitf7329619da8d493fecc30e2a5d44bc3a672741a3 (patch)
treef32283b04dba21d50a724de18b160fb46cccc804 /unsupported/test/cxx11_tensor_contraction.cpp
parent46aa9772fcb92f6be5e90a37f4e585e670220348 (diff)
Fix bug in tensor contraction. The code assumes that contraction axis indices for the LHS (after possibly swapping to ColMajor!) is increasing. Explicitly sort the contraction axis pairs to make it so.
Diffstat (limited to 'unsupported/test/cxx11_tensor_contraction.cpp')
-rw-r--r--unsupported/test/cxx11_tensor_contraction.cpp20
1 files changed, 20 insertions, 0 deletions
diff --git a/unsupported/test/cxx11_tensor_contraction.cpp b/unsupported/test/cxx11_tensor_contraction.cpp
index 57ec5add7..0e16308a2 100644
--- a/unsupported/test/cxx11_tensor_contraction.cpp
+++ b/unsupported/test/cxx11_tensor_contraction.cpp
@@ -138,6 +138,26 @@ static void test_multidims()
mat1(1,0,1)*mat2(1,0,0,1) + mat1(1,1,1)*mat2(1,0,1,1));
VERIFY_IS_APPROX(mat3(1,1,1), mat1(1,0,0)*mat2(1,1,0,0) + mat1(1,1,0)*mat2(1,1,1,0) +
mat1(1,0,1)*mat2(1,1,0,1) + mat1(1,1,1)*mat2(1,1,1,1));
+
+ Tensor<float, 2, DataLayout> mat4(2, 2);
+ Tensor<float, 3, DataLayout> mat5(2, 2, 2);
+
+ mat4.setRandom();
+ mat5.setRandom();
+
+ Tensor<float, 1, DataLayout> mat6(2);
+ mat6.setZero();
+ Eigen::array<DimPair, 2> dims2({{DimPair(0, 1), DimPair(1, 0)}});
+ typedef TensorEvaluator<decltype(mat4.contract(mat5, dims2)), DefaultDevice> Evaluator2;
+ Evaluator2 eval2(mat4.contract(mat5, dims2), DefaultDevice());
+ eval2.evalTo(mat6.data());
+ EIGEN_STATIC_ASSERT(Evaluator2::NumDims==1ul, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ VERIFY_IS_EQUAL(eval2.dimensions()[0], 2);
+
+ VERIFY_IS_APPROX(mat6(0), mat4(0,0)*mat5(0,0,0) + mat4(1,0)*mat5(0,1,0) +
+ mat4(0,1)*mat5(1,0,0) + mat4(1,1)*mat5(1,1,0));
+ VERIFY_IS_APPROX(mat6(1), mat4(0,0)*mat5(0,0,1) + mat4(1,0)*mat5(0,1,1) +
+ mat4(0,1)*mat5(1,0,1) + mat4(1,1)*mat5(1,1,1));
}
template<int DataLayout>