aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2018-10-15 23:47:46 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2018-10-15 23:47:46 +0200
commitf0fb95135dfd2d109a793fe8793b13c401f36bf4 (patch)
treec9c31808f908cdbb2b119fc4bd59addfcd077376 /Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h
parent2747b98cfc39d7bd4b4dd56d4fed2adf30219509 (diff)
Iterative solvers: unify and fix handling of multiple rhs.
m_info was not properly computed and the logic was repeated in several places.
Diffstat (limited to 'Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h')
-rw-r--r--Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h55
1 files changed, 53 insertions, 2 deletions
diff --git a/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h b/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h
index bfeee71cd..9d08e6d11 100644
--- a/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h
+++ b/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h
@@ -331,7 +331,7 @@ public:
/** \internal */
template<typename Rhs, typename DestDerived>
- void _solve_impl(const Rhs& b, SparseMatrixBase<DestDerived> &aDest) const
+ void _solve_with_guess_impl(const Rhs& b, SparseMatrixBase<DestDerived> &aDest) const
{
eigen_assert(rows()==b.rows());
@@ -344,15 +344,66 @@ public:
// We do not directly fill dest because sparse expressions have to be free of aliasing issue.
// For non square least-square problems, b and dest might not have the same size whereas they might alias each-other.
typename DestDerived::PlainObject tmp(cols(),rhsCols);
+ ComputationInfo global_info = Success;
for(Index k=0; k<rhsCols; ++k)
{
tb = b.col(k);
- tx = derived().solve(tb);
+ tx = dest.col(k);
+ derived()._solve_vector_with_guess_impl(tb,tx);
tmp.col(k) = tx.sparseView(0);
+
+ // The call to _solve_vector_with_guess_impl updates m_info, so if it failed for a previous column
+ // we need to restore it to the worst value.
+ if(m_info==NumericalIssue)
+ global_info = NumericalIssue;
+ else if(m_info==NoConvergence)
+ global_info = NoConvergence;
}
+ m_info = global_info;
dest.swap(tmp);
}
+ template<typename Rhs, typename DestDerived>
+ typename internal::enable_if<Rhs::ColsAtCompileTime!=1 && DestDerived::ColsAtCompileTime!=1>::type
+ _solve_with_guess_impl(const Rhs& b, MatrixBase<DestDerived> &aDest) const
+ {
+ eigen_assert(rows()==b.rows());
+
+ Index rhsCols = b.cols();
+ DestDerived& dest(aDest.derived());
+ ComputationInfo global_info = Success;
+ for(Index k=0; k<rhsCols; ++k)
+ {
+ typename DestDerived::ColXpr xk(dest,k);
+ typename Rhs::ConstColXpr bk(b,k);
+ derived()._solve_vector_with_guess_impl(bk,xk);
+
+ // The call to _solve_vector_with_guess updates m_info, so if it failed for a previous column
+ // we need to restore it to the worst value.
+ if(m_info==NumericalIssue)
+ global_info = NumericalIssue;
+ else if(m_info==NoConvergence)
+ global_info = NoConvergence;
+ }
+ m_info = global_info;
+ }
+
+ template<typename Rhs, typename DestDerived>
+ typename internal::enable_if<Rhs::ColsAtCompileTime==1 || DestDerived::ColsAtCompileTime==1>::type
+ _solve_with_guess_impl(const Rhs& b, MatrixBase<DestDerived> &dest) const
+ {
+ derived()._solve_vector_with_guess_impl(b,dest.derived());
+ }
+
+ /** \internal default initial guess = 0 */
+ template<typename Rhs,typename Dest>
+ void _solve_impl(const Rhs& b, Dest& x) const
+ {
+ x.resize(this->rows(),b.cols());
+ x.setZero();
+ derived()._solve_with_guess_impl(b,x);
+ }
+
protected:
void init()
{