diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-05-05 08:37:47 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-05-05 08:37:47 -0700 |
commit | 06d774bf5865cbecfde868b2554c177d95988552 (patch) | |
tree | a44f07223de04796cfa47dd5b3b4b2e98737c5f0 /unsupported | |
parent | b300a84989cd13a779ead980c3197c2599b15f14 (diff) |
Updated the contraction code to ensure that full contractions return a tensor of rank 0
Diffstat (limited to 'unsupported')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h | 26 | ||||
-rw-r--r-- | unsupported/test/cxx11_tensor_contraction.cpp | 9 |
2 files changed, 11 insertions, 24 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index 6f113b903..9d0d432ee 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -37,7 +37,7 @@ struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> > typedef typename remove_reference<RhsNested>::type _RhsNested; // From NumDims below. - static const int NumDimensions = max_n_1<traits<RhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value>::size; + static const int NumDimensions = traits<RhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value; static const int Layout = traits<LhsXprType>::Layout; enum { @@ -65,7 +65,7 @@ struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, typedef Device_ Device; // From NumDims below. - static const int NumDimensions = max_n_1<traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value>::size; + static const int NumDimensions = traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value; }; } // end namespace internal @@ -140,7 +140,7 @@ struct TensorContractionEvaluatorBase static const int RDims = internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value; static const int ContractDims = internal::array_size<Indices>::value; - static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size; + static const int NumDims = LDims + RDims - 2 * ContractDims; typedef array<Index, ContractDims> contract_t; typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t; @@ -218,11 +218,9 @@ struct TensorContractionEvaluatorBase rhs_strides[i+1] = rhs_strides[i] * eval_right_dims[i]; } - m_i_strides[0] = 1; - m_j_strides[0] = 1; - if(ContractDims) { - 
m_k_strides[0] = 1; - } + if (m_i_strides.size() > 0) m_i_strides[0] = 1; + if (m_j_strides.size() > 0) m_j_strides[0] = 1; + if (m_k_strides.size() > 0) m_k_strides[0] = 1; m_i_size = 1; m_j_size = 1; @@ -318,11 +316,6 @@ struct TensorContractionEvaluatorBase } } - // Scalar case. We represent the result as a 1d tensor of size 1. - if (LDims + RDims == 2 * ContractDims) { - m_dimensions[0] = 1; - } - // If the layout is RowMajor, we need to reverse the m_dimensions if (static_cast<int>(Layout) == static_cast<int>(RowMajor)) { for (int i = 0, j = NumDims - 1; i < j; i++, j--) { @@ -607,15 +600,14 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT static const int ContractDims = internal::array_size<Indices>::value; typedef array<Index, ContractDims> contract_t; - typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t; - typedef array<Index, max_n_1<RDims - ContractDims>::size> right_nocontract_t; + typedef array<Index, LDims - ContractDims> left_nocontract_t; + typedef array<Index, RDims - ContractDims> right_nocontract_t; - static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size; + static const int NumDims = LDims + RDims - 2 * ContractDims; // Could we use NumDimensions here? 
typedef DSizes<Index, NumDims> Dimensions; - EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) : Base(op, device) { } diff --git a/unsupported/test/cxx11_tensor_contraction.cpp b/unsupported/test/cxx11_tensor_contraction.cpp index 0e16308a2..73623b2ed 100644 --- a/unsupported/test/cxx11_tensor_contraction.cpp +++ b/unsupported/test/cxx11_tensor_contraction.cpp @@ -87,19 +87,14 @@ static void test_scalar() vec1.setRandom(); vec2.setRandom(); - Tensor<float, 1, DataLayout> scalar(1); - scalar.setZero(); Eigen::array<DimPair, 1> dims = {{DimPair(0, 0)}}; - typedef TensorEvaluator<decltype(vec1.contract(vec2, dims)), DefaultDevice> Evaluator; - Evaluator eval(vec1.contract(vec2, dims), DefaultDevice()); - eval.evalTo(scalar.data()); - EIGEN_STATIC_ASSERT(Evaluator::NumDims==1ul, YOU_MADE_A_PROGRAMMING_MISTAKE); + Tensor<float, 0, DataLayout> scalar = vec1.contract(vec2, dims); float expected = 0.0f; for (int i = 0; i < 6; ++i) { expected += vec1(i) * vec2(i); } - VERIFY_IS_APPROX(scalar(0), expected); + VERIFY_IS_APPROX(scalar(), expected); } template<int DataLayout> |