diff options
Diffstat (limited to 'Eigen/src/IterativeLinearSolvers')
4 files changed, 66 insertions, 66 deletions
diff --git a/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h b/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h index 454f46814..153acef65 100644 --- a/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +++ b/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h @@ -191,32 +191,16 @@ public: /** \internal */ template<typename Rhs,typename Dest> - void _solve_with_guess_impl(const Rhs& b, Dest& x) const + void _solve_vector_with_guess_impl(const Rhs& b, Dest& x) const { - bool failed = false; - for(Index j=0; j<b.cols(); ++j) - { - m_iterations = Base::maxIterations(); - m_error = Base::m_tolerance; - - typename Dest::ColXpr xj(x,j); - if(!internal::bicgstab(matrix(), b.col(j), xj, Base::m_preconditioner, m_iterations, m_error)) - failed = true; - } - m_info = failed ? NumericalIssue + m_iterations = Base::maxIterations(); + m_error = Base::m_tolerance; + + bool ret = internal::bicgstab(matrix(), b, x, Base::m_preconditioner, m_iterations, m_error); + + m_info = (!ret) ? NumericalIssue : m_error <= Base::m_tolerance ? Success : NoConvergence; - m_isInitialized = true; - } - - /** \internal */ - using Base::_solve_impl; - template<typename Rhs,typename Dest> - void _solve_impl(const MatrixBase<Rhs>& b, Dest& x) const - { - x.resize(this->rows(),b.cols()); - x.setZero(); - _solve_with_guess_impl(b,x); } protected: diff --git a/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h b/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h index f7ce47134..96e8b9f8a 100644 --- a/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +++ b/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h @@ -195,7 +195,7 @@ public: /** \internal */ template<typename Rhs,typename Dest> - void _solve_with_guess_impl(const Rhs& b, Dest& x) const + void _solve_vector_with_guess_impl(const Rhs& b, Dest& x) const { typedef typename Base::MatrixWrapper MatrixWrapper; typedef typename Base::ActualMatrixType ActualMatrixType; @@ -211,31 +211,14 @@ public: RowMajorWrapper, typename MatrixWrapper::template ConstSelfAdjointViewReturnType<UpLo>::Type >::type SelfAdjointWrapper; + m_iterations = Base::maxIterations(); m_error = Base::m_tolerance; - for(Index j=0; j<b.cols(); ++j) - { - m_iterations = Base::maxIterations(); - m_error = Base::m_tolerance; - - typename Dest::ColXpr xj(x,j); - RowMajorWrapper row_mat(matrix()); - internal::conjugate_gradient(SelfAdjointWrapper(row_mat), b.col(j), xj, Base::m_preconditioner, m_iterations, m_error); - } - - m_isInitialized = true; + RowMajorWrapper row_mat(matrix()); + internal::conjugate_gradient(SelfAdjointWrapper(row_mat), b, x, Base::m_preconditioner, m_iterations, m_error); m_info = m_error <= Base::m_tolerance ? Success : NoConvergence; } - - /** \internal */ - using Base::_solve_impl; - template<typename Rhs,typename Dest> - void _solve_impl(const MatrixBase<Rhs>& b, Dest& x) const - { - x.setZero(); - _solve_with_guess_impl(b.derived(),x); - } protected: diff --git a/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h b/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h index bfeee71cd..9d08e6d11 100644 --- a/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +++ b/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h @@ -331,7 +331,7 @@ public: /** \internal */ template<typename Rhs, typename DestDerived> - void _solve_impl(const Rhs& b, SparseMatrixBase<DestDerived> &aDest) const + void _solve_with_guess_impl(const Rhs& b, SparseMatrixBase<DestDerived> &aDest) const { eigen_assert(rows()==b.rows()); @@ -344,15 +344,66 @@ public: // We do not directly fill dest because sparse expressions have to be free of aliasing issue. // For non square least-square problems, b and dest might not have the same size whereas they might alias each-other. typename DestDerived::PlainObject tmp(cols(),rhsCols); + ComputationInfo global_info = Success; for(Index k=0; k<rhsCols; ++k) { tb = b.col(k); - tx = derived().solve(tb); + tx = dest.col(k); + derived()._solve_vector_with_guess_impl(tb,tx); tmp.col(k) = tx.sparseView(0); + + // The call to _solve_vector_with_guess_impl updates m_info, so if it failed for a previous column + // we need to restore it to the worst value. + if(m_info==NumericalIssue) + global_info = NumericalIssue; + else if(m_info==NoConvergence) + global_info = NoConvergence; } + m_info = global_info; dest.swap(tmp); } + template<typename Rhs, typename DestDerived> + typename internal::enable_if<Rhs::ColsAtCompileTime!=1 && DestDerived::ColsAtCompileTime!=1>::type + _solve_with_guess_impl(const Rhs& b, MatrixBase<DestDerived> &aDest) const + { + eigen_assert(rows()==b.rows()); + + Index rhsCols = b.cols(); + DestDerived& dest(aDest.derived()); + ComputationInfo global_info = Success; + for(Index k=0; k<rhsCols; ++k) + { + typename DestDerived::ColXpr xk(dest,k); + typename Rhs::ConstColXpr bk(b,k); + derived()._solve_vector_with_guess_impl(bk,xk); + + // The call to _solve_vector_with_guess updates m_info, so if it failed for a previous column + // we need to restore it to the worst value. + if(m_info==NumericalIssue) + global_info = NumericalIssue; + else if(m_info==NoConvergence) + global_info = NoConvergence; + } + m_info = global_info; + } + + template<typename Rhs, typename DestDerived> + typename internal::enable_if<Rhs::ColsAtCompileTime==1 || DestDerived::ColsAtCompileTime==1>::type + _solve_with_guess_impl(const Rhs& b, MatrixBase<DestDerived> &dest) const + { + derived()._solve_vector_with_guess_impl(b,dest.derived()); + } + + /** \internal default initial guess = 0 */ + template<typename Rhs,typename Dest> + void _solve_impl(const Rhs& b, Dest& x) const + { + x.resize(this->rows(),b.cols()); + x.setZero(); + derived()._solve_with_guess_impl(b,x); + } + protected: void init() { diff --git a/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h b/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h index 0aea0e099..203fd0ec6 100644 --- a/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +++ b/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h @@ -182,32 +182,14 @@ public: /** \internal */ template<typename Rhs,typename Dest> - void _solve_with_guess_impl(const Rhs& b, Dest& x) const + void _solve_vector_with_guess_impl(const Rhs& b, Dest& x) const { m_iterations = Base::maxIterations(); m_error = Base::m_tolerance; - for(Index j=0; j<b.cols(); ++j) - { - m_iterations = Base::maxIterations(); - m_error = Base::m_tolerance; - - typename Dest::ColXpr xj(x,j); - internal::least_square_conjugate_gradient(matrix(), b.col(j), xj, Base::m_preconditioner, m_iterations, m_error); - } - - m_isInitialized = true; + internal::least_square_conjugate_gradient(matrix(), b, x, Base::m_preconditioner, m_iterations, m_error); m_info = m_error <= Base::m_tolerance ? Success : NoConvergence; } - - /** \internal */ - using Base::_solve_impl; - template<typename Rhs,typename Dest> - void _solve_impl(const MatrixBase<Rhs>& b, Dest& x) const - { - x.setZero(); - _solve_with_guess_impl(b.derived(),x); - } }; |