diff options
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h | 13 | ||||
-rw-r--r-- | unsupported/test/cxx11_tensor_contraction.cpp | 20 |
2 files changed, 33 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index 18b20b2dc..f070ba61e 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -193,6 +193,19 @@ struct TensorContractionEvaluatorBase } } + // Check for duplicate axes and make sure the first index in eval_op_indices + // is increasing. Using O(n^2) sorting is OK since ContractDims is small + for (int i = 0; i < ContractDims; i++) { + for (int j = i + 1; j < ContractDims; j++) { + eigen_assert(eval_op_indices[j].first != eval_op_indices[i].first && + eval_op_indices[j].second != eval_op_indices[i].second && + "contraction axes should be unique"); + if (eval_op_indices[j].first < eval_op_indices[i].first) { + numext::swap(eval_op_indices[j], eval_op_indices[i]); + } + } + } + array<Index, LDims> lhs_strides; lhs_strides[0] = 1; for (int i = 0; i < LDims-1; ++i) { diff --git a/unsupported/test/cxx11_tensor_contraction.cpp b/unsupported/test/cxx11_tensor_contraction.cpp index 57ec5add7..0e16308a2 100644 --- a/unsupported/test/cxx11_tensor_contraction.cpp +++ b/unsupported/test/cxx11_tensor_contraction.cpp @@ -138,6 +138,26 @@ static void test_multidims() mat1(1,0,1)*mat2(1,0,0,1) + mat1(1,1,1)*mat2(1,0,1,1)); VERIFY_IS_APPROX(mat3(1,1,1), mat1(1,0,0)*mat2(1,1,0,0) + mat1(1,1,0)*mat2(1,1,1,0) + mat1(1,0,1)*mat2(1,1,0,1) + mat1(1,1,1)*mat2(1,1,1,1)); + + Tensor<float, 2, DataLayout> mat4(2, 2); + Tensor<float, 3, DataLayout> mat5(2, 2, 2); + + mat4.setRandom(); + mat5.setRandom(); + + Tensor<float, 1, DataLayout> mat6(2); + mat6.setZero(); + Eigen::array<DimPair, 2> dims2({{DimPair(0, 1), DimPair(1, 0)}}); + typedef TensorEvaluator<decltype(mat4.contract(mat5, dims2)), DefaultDevice> Evaluator2; + Evaluator2 eval2(mat4.contract(mat5, dims2), DefaultDevice()); + eval2.evalTo(mat6.data()); + EIGEN_STATIC_ASSERT(Evaluator2::NumDims==1ul, YOU_MADE_A_PROGRAMMING_MISTAKE); + VERIFY_IS_EQUAL(eval2.dimensions()[0], 2); + + VERIFY_IS_APPROX(mat6(0), mat4(0,0)*mat5(0,0,0) + mat4(1,0)*mat5(0,1,0) + + mat4(0,1)*mat5(1,0,0) + mat4(1,1)*mat5(1,1,0)); + VERIFY_IS_APPROX(mat6(1), mat4(0,0)*mat5(0,0,1) + mat4(1,0)*mat5(0,1,1) + + mat4(0,1)*mat5(1,0,1) + mat4(1,1)*mat5(1,1,1)); } template<int DataLayout> |