diff options
author | Gael Guennebaud <g.gael@free.fr> | 2014-09-18 22:08:49 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2014-09-18 22:08:49 +0200 |
commit | 2ae20d558b33653b8ef2fe17255ed171997bcf79 (patch) | |
tree | da9a126991a414ad1d4c174a1f0271835bcd0d89 /unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h | |
parent | 62bce6e5e6da71dd8d85ae229d24b9f9f13d1681 (diff) |
Update KroneckerProduct wrt evaluator changes
Diffstat (limited to 'unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h')
-rw-r--r-- | unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h | 36 |
1 files changed, 30 insertions, 6 deletions
diff --git a/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h b/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h index ca66d4d89..72e25db19 100644 --- a/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h +++ b/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h @@ -154,16 +154,41 @@ void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const dst.resize(this->rows(), this->cols()); dst.resizeNonZeros(0); + // 1 - evaluate the operands if needed: + typedef typename internal::nested_eval<Lhs,10>::type Lhs1; + typedef typename internal::remove_all<Lhs1>::type Lhs1Cleaned; + const Lhs1 lhs1(m_A); + typedef typename internal::nested_eval<Rhs,10>::type Rhs1; + typedef typename internal::remove_all<Rhs1>::type Rhs1Cleaned; + const Rhs1 rhs1(m_B); + + // 2 - construct a SparseView for dense operands + typedef typename internal::conditional<internal::is_same<typename internal::traits<Lhs1Cleaned>::StorageKind,Sparse>::value, Lhs1, SparseView<const Lhs1Cleaned> >::type Lhs2; + typedef typename internal::remove_all<Lhs2>::type Lhs2Cleaned; + const Lhs2 lhs2(lhs1); + typedef typename internal::conditional<internal::is_same<typename internal::traits<Rhs1Cleaned>::StorageKind,Sparse>::value, Rhs1, SparseView<const Rhs1Cleaned> >::type Rhs2; + typedef typename internal::remove_all<Rhs2>::type Rhs2Cleaned; + const Rhs2 rhs2(rhs1); + + // 3 - construct respective evaluators + typedef typename internal::evaluator<Lhs2Cleaned>::type LhsEval; + LhsEval lhsEval(lhs2); + typedef typename internal::evaluator<Rhs2Cleaned>::type RhsEval; + RhsEval rhsEval(rhs2); + + typedef typename LhsEval::InnerIterator LhsInnerIterator; + typedef typename RhsEval::InnerIterator RhsInnerIterator; + // compute number of non-zeros per innervectors of dst { VectorXi nnzA = VectorXi::Zero(Dest::IsRowMajor ? m_A.rows() : m_A.cols()); for (Index kA=0; kA < m_A.outerSize(); ++kA) - for (typename Lhs::InnerIterator itA(m_A,kA); itA; ++itA) + for (LhsInnerIterator itA(lhsEval,kA); itA; ++itA) nnzA(Dest::IsRowMajor ? itA.row() : itA.col())++; VectorXi nnzB = VectorXi::Zero(Dest::IsRowMajor ? m_B.rows() : m_B.cols()); for (Index kB=0; kB < m_B.outerSize(); ++kB) - for (typename Rhs::InnerIterator itB(m_B,kB); itB; ++itB) + for (RhsInnerIterator itB(rhsEval,kB); itB; ++itB) nnzB(Dest::IsRowMajor ? itB.row() : itB.col())++; Matrix<int,Dynamic,Dynamic,ColMajor> nnzAB = nnzB * nnzA.transpose(); @@ -174,9 +199,9 @@ void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const { for (Index kB=0; kB < m_B.outerSize(); ++kB) { - for (typename Lhs::InnerIterator itA(m_A,kA); itA; ++itA) + for (LhsInnerIterator itA(lhsEval,kA); itA; ++itA) { - for (typename Rhs::InnerIterator itB(m_B,kB); itB; ++itB) + for (RhsInnerIterator itB(rhsEval,kB); itB; ++itB) { const Index i = itA.row() * Br + itB.row(), j = itA.col() * Bc + itB.col(); @@ -201,8 +226,7 @@ struct traits<KroneckerProduct<_Lhs,_Rhs> > Rows = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret, Cols = size_at_compile_time<traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime>::ret, MaxRows = size_at_compile_time<traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime>::ret, - MaxCols = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret, - CoeffReadCost = Lhs::CoeffReadCost + Rhs::CoeffReadCost + NumTraits<Scalar>::MulCost + MaxCols = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret }; typedef Matrix<Scalar,Rows,Cols> ReturnType; |