aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h13
-rw-r--r--unsupported/test/cxx11_tensor_contraction.cpp20
2 files changed, 33 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
index 18b20b2dc..f070ba61e 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
@@ -193,6 +193,19 @@ struct TensorContractionEvaluatorBase
}
}
+ // Check for duplicate axes and make sure the first index in eval_op_indices
+ // is increasing. Using O(n^2) sorting is OK since ContractDims is small
+ for (int i = 0; i < ContractDims; i++) {
+ for (int j = i + 1; j < ContractDims; j++) {
+ eigen_assert(eval_op_indices[j].first != eval_op_indices[i].first &&
+ eval_op_indices[j].second != eval_op_indices[i].second &&
+ "contraction axes should be unique");
+ if (eval_op_indices[j].first < eval_op_indices[i].first) {
+ numext::swap(eval_op_indices[j], eval_op_indices[i]);
+ }
+ }
+ }
+
array<Index, LDims> lhs_strides;
lhs_strides[0] = 1;
for (int i = 0; i < LDims-1; ++i) {
diff --git a/unsupported/test/cxx11_tensor_contraction.cpp b/unsupported/test/cxx11_tensor_contraction.cpp
index 57ec5add7..0e16308a2 100644
--- a/unsupported/test/cxx11_tensor_contraction.cpp
+++ b/unsupported/test/cxx11_tensor_contraction.cpp
@@ -138,6 +138,26 @@ static void test_multidims()
mat1(1,0,1)*mat2(1,0,0,1) + mat1(1,1,1)*mat2(1,0,1,1));
VERIFY_IS_APPROX(mat3(1,1,1), mat1(1,0,0)*mat2(1,1,0,0) + mat1(1,1,0)*mat2(1,1,1,0) +
mat1(1,0,1)*mat2(1,1,0,1) + mat1(1,1,1)*mat2(1,1,1,1));
+
+ Tensor<float, 2, DataLayout> mat4(2, 2);
+ Tensor<float, 3, DataLayout> mat5(2, 2, 2);
+
+ mat4.setRandom();
+ mat5.setRandom();
+
+ Tensor<float, 1, DataLayout> mat6(2);
+ mat6.setZero();
+ Eigen::array<DimPair, 2> dims2({{DimPair(0, 1), DimPair(1, 0)}});
+ typedef TensorEvaluator<decltype(mat4.contract(mat5, dims2)), DefaultDevice> Evaluator2;
+ Evaluator2 eval2(mat4.contract(mat5, dims2), DefaultDevice());
+ eval2.evalTo(mat6.data());
+ EIGEN_STATIC_ASSERT(Evaluator2::NumDims==1ul, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ VERIFY_IS_EQUAL(eval2.dimensions()[0], 2);
+
+ VERIFY_IS_APPROX(mat6(0), mat4(0,0)*mat5(0,0,0) + mat4(1,0)*mat5(0,1,0) +
+ mat4(0,1)*mat5(1,0,0) + mat4(1,1)*mat5(1,1,0));
+ VERIFY_IS_APPROX(mat6(1), mat4(0,0)*mat5(0,0,1) + mat4(1,0)*mat5(0,1,1) +
+ mat4(0,1)*mat5(1,0,1) + mat4(1,1)*mat5(1,1,1));
}
template<int DataLayout>