diff options
Diffstat (limited to 'Eigen/src/SparseCore/TriangularSolver.h')
-rw-r--r-- | Eigen/src/SparseCore/TriangularSolver.h | 56 |
1 files changed, 28 insertions, 28 deletions
diff --git a/Eigen/src/SparseCore/TriangularSolver.h b/Eigen/src/SparseCore/TriangularSolver.h index dd55522a7..98062e9c6 100644 --- a/Eigen/src/SparseCore/TriangularSolver.h +++ b/Eigen/src/SparseCore/TriangularSolver.h @@ -29,8 +29,11 @@ struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,RowMajor> { typedef typename Rhs::Scalar Scalar; typedef typename Lhs::Index Index; + typedef typename evaluator<Lhs>::type LhsEval; + typedef typename evaluator<Lhs>::InnerIterator LhsIterator; static void run(const Lhs& lhs, Rhs& other) { + LhsEval lhsEval(lhs); for(Index col=0 ; col<other.cols() ; ++col) { for(Index i=0; i<lhs.rows(); ++i) @@ -38,7 +41,7 @@ struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,RowMajor> Scalar tmp = other.coeff(i,col); Scalar lastVal(0); Index lastIndex = 0; - for(typename Lhs::InnerIterator it(lhs, i); it; ++it) + for(LhsIterator it(lhsEval, i); it; ++it) { lastVal = it.value(); lastIndex = it.index(); @@ -64,15 +67,18 @@ struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,RowMajor> { typedef typename Rhs::Scalar Scalar; typedef typename Lhs::Index Index; + typedef typename evaluator<Lhs>::type LhsEval; + typedef typename evaluator<Lhs>::InnerIterator LhsIterator; static void run(const Lhs& lhs, Rhs& other) { + LhsEval lhsEval(lhs); for(Index col=0 ; col<other.cols() ; ++col) { for(Index i=lhs.rows()-1 ; i>=0 ; --i) { Scalar tmp = other.coeff(i,col); Scalar l_ii = 0; - typename Lhs::InnerIterator it(lhs, i); + LhsIterator it(lhsEval, i); while(it && it.index()<i) ++it; if(!(Mode & UnitDiag)) @@ -88,10 +94,8 @@ struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,RowMajor> tmp -= it.value() * other.coeff(it.index(),col); } - if (Mode & UnitDiag) - other.coeffRef(i,col) = tmp; - else - other.coeffRef(i,col) = tmp/l_ii; + if (Mode & UnitDiag) other.coeffRef(i,col) = tmp; + else other.coeffRef(i,col) = tmp/l_ii; } } } @@ -103,8 +107,11 @@ struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,ColMajor> { typedef typename Rhs::Scalar Scalar; typedef typename Lhs::Index Index; + typedef typename evaluator<Lhs>::type LhsEval; + typedef typename evaluator<Lhs>::InnerIterator LhsIterator; static void run(const Lhs& lhs, Rhs& other) { + LhsEval lhsEval(lhs); for(Index col=0 ; col<other.cols() ; ++col) { for(Index i=0; i<lhs.cols(); ++i) @@ -112,7 +119,7 @@ struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,ColMajor> Scalar& tmp = other.coeffRef(i,col); if (tmp!=Scalar(0)) // optimization when other is actually sparse { - typename Lhs::InnerIterator it(lhs, i); + LhsIterator it(lhsEval, i); while(it && it.index()<i) ++it; if(!(Mode & UnitDiag)) @@ -136,8 +143,11 @@ struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,ColMajor> { typedef typename Rhs::Scalar Scalar; typedef typename Lhs::Index Index; + typedef typename evaluator<Lhs>::type LhsEval; + typedef typename evaluator<Lhs>::InnerIterator LhsIterator; static void run(const Lhs& lhs, Rhs& other) { + LhsEval lhsEval(lhs); for(Index col=0 ; col<other.cols() ; ++col) { for(Index i=lhs.cols()-1; i>=0; --i) @@ -148,13 +158,13 @@ struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,ColMajor> if(!(Mode & UnitDiag)) { // TODO replace this by a binary search. make sure the binary search is safe for partially sorted elements - typename Lhs::ReverseInnerIterator it(lhs, i); + LhsIterator it(lhsEval, i); while(it && it.index()!=i) - --it; + ++it; eigen_assert(it && it.index()==i); other.coeffRef(i,col) /= it.value(); } - typename Lhs::InnerIterator it(lhs, i); + LhsIterator it(lhsEval, i); for(; it && it.index()<i; ++it) other.coeffRef(it.index(), col) -= tmp * it.value(); } @@ -165,11 +175,11 @@ struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,ColMajor> } // end namespace internal -template<typename ExpressionType,int Mode> +template<typename ExpressionType,unsigned int Mode> template<typename OtherDerived> -void SparseTriangularView<ExpressionType,Mode>::solveInPlace(MatrixBase<OtherDerived>& other) const +void TriangularViewImpl<ExpressionType,Mode,Sparse>::solveInPlace(MatrixBase<OtherDerived>& other) const { - eigen_assert(m_matrix.cols() == m_matrix.rows() && m_matrix.cols() == other.rows()); + eigen_assert(derived().cols() == derived().rows() && derived().cols() == other.rows()); eigen_assert((!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower))); enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit }; @@ -178,22 +188,12 @@ void SparseTriangularView<ExpressionType,Mode>::solveInPlace(MatrixBase<OtherDer typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy; OtherCopy otherCopy(other.derived()); - internal::sparse_solve_triangular_selector<ExpressionType, typename internal::remove_reference<OtherCopy>::type, Mode>::run(m_matrix, otherCopy); + internal::sparse_solve_triangular_selector<ExpressionType, typename internal::remove_reference<OtherCopy>::type, Mode>::run(derived().nestedExpression(), otherCopy); if (copy) other = otherCopy; } -template<typename ExpressionType,int Mode> -template<typename OtherDerived> -typename internal::plain_matrix_type_column_major<OtherDerived>::type -SparseTriangularView<ExpressionType,Mode>::solve(const MatrixBase<OtherDerived>& other) const -{ - typename internal::plain_matrix_type_column_major<OtherDerived>::type res(other); - solveInPlace(res); - return res; -} - // pure sparse path namespace internal { @@ -290,11 +290,11 @@ struct sparse_solve_triangular_sparse_selector<Lhs,Rhs,Mode,UpLo,ColMajor> } // end namespace internal -template<typename ExpressionType,int Mode> +template<typename ExpressionType,unsigned int Mode> template<typename OtherDerived> -void SparseTriangularView<ExpressionType,Mode>::solveInPlace(SparseMatrixBase<OtherDerived>& other) const +void TriangularViewImpl<ExpressionType,Mode,Sparse>::solveInPlace(SparseMatrixBase<OtherDerived>& other) const { - eigen_assert(m_matrix.cols() == m_matrix.rows() && m_matrix.cols() == other.rows()); + eigen_assert(derived().cols() == derived().rows() && derived().cols() == other.rows()); eigen_assert( (!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower))); // enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit }; @@ -303,7 +303,7 @@ void SparseTriangularView<ExpressionType,Mode>::solveInPlace(SparseMatrixBase<Ot // typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy; // OtherCopy otherCopy(other.derived()); - internal::sparse_solve_triangular_sparse_selector<ExpressionType, OtherDerived, Mode>::run(m_matrix, other.derived()); + internal::sparse_solve_triangular_sparse_selector<ExpressionType, OtherDerived, Mode>::run(derived().nestedExpression(), other.derived()); // if (copy) // other = otherCopy; |