aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/IterativeLinearSolvers
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2018-10-15 23:47:46 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2018-10-15 23:47:46 +0200
commitf0fb95135dfd2d109a793fe8793b13c401f36bf4 (patch)
treec9c31808f908cdbb2b119fc4bd59addfcd077376 /Eigen/src/IterativeLinearSolvers
parent2747b98cfc39d7bd4b4dd56d4fed2adf30219509 (diff)
Iterative solvers: unify and fix handling of multiple rhs.
m_info was not properly computed and the logic was repeated in several places.
Diffstat (limited to 'Eigen/src/IterativeLinearSolvers')
-rw-r--r--Eigen/src/IterativeLinearSolvers/BiCGSTAB.h30
-rw-r--r--Eigen/src/IterativeLinearSolvers/ConjugateGradient.h25
-rw-r--r--Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h55
-rw-r--r--Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h22
4 files changed, 66 insertions, 66 deletions
diff --git a/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h b/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h
index 454f46814..153acef65 100644
--- a/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h
+++ b/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h
@@ -191,32 +191,16 @@ public:
/** \internal */
template<typename Rhs,typename Dest>
- void _solve_with_guess_impl(const Rhs& b, Dest& x) const
+ void _solve_vector_with_guess_impl(const Rhs& b, Dest& x) const
{
- bool failed = false;
- for(Index j=0; j<b.cols(); ++j)
- {
- m_iterations = Base::maxIterations();
- m_error = Base::m_tolerance;
-
- typename Dest::ColXpr xj(x,j);
- if(!internal::bicgstab(matrix(), b.col(j), xj, Base::m_preconditioner, m_iterations, m_error))
- failed = true;
- }
- m_info = failed ? NumericalIssue
+ m_iterations = Base::maxIterations();
+ m_error = Base::m_tolerance;
+
+ bool ret = internal::bicgstab(matrix(), b, x, Base::m_preconditioner, m_iterations, m_error);
+
+ m_info = (!ret) ? NumericalIssue
: m_error <= Base::m_tolerance ? Success
: NoConvergence;
- m_isInitialized = true;
- }
-
- /** \internal */
- using Base::_solve_impl;
- template<typename Rhs,typename Dest>
- void _solve_impl(const MatrixBase<Rhs>& b, Dest& x) const
- {
- x.resize(this->rows(),b.cols());
- x.setZero();
- _solve_with_guess_impl(b,x);
}
protected:
diff --git a/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h b/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h
index f7ce47134..96e8b9f8a 100644
--- a/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h
+++ b/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h
@@ -195,7 +195,7 @@ public:
/** \internal */
template<typename Rhs,typename Dest>
- void _solve_with_guess_impl(const Rhs& b, Dest& x) const
+ void _solve_vector_with_guess_impl(const Rhs& b, Dest& x) const
{
typedef typename Base::MatrixWrapper MatrixWrapper;
typedef typename Base::ActualMatrixType ActualMatrixType;
@@ -211,31 +211,14 @@ public:
RowMajorWrapper,
typename MatrixWrapper::template ConstSelfAdjointViewReturnType<UpLo>::Type
>::type SelfAdjointWrapper;
+
m_iterations = Base::maxIterations();
m_error = Base::m_tolerance;
- for(Index j=0; j<b.cols(); ++j)
- {
- m_iterations = Base::maxIterations();
- m_error = Base::m_tolerance;
-
- typename Dest::ColXpr xj(x,j);
- RowMajorWrapper row_mat(matrix());
- internal::conjugate_gradient(SelfAdjointWrapper(row_mat), b.col(j), xj, Base::m_preconditioner, m_iterations, m_error);
- }
-
- m_isInitialized = true;
+ RowMajorWrapper row_mat(matrix());
+ internal::conjugate_gradient(SelfAdjointWrapper(row_mat), b, x, Base::m_preconditioner, m_iterations, m_error);
m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
}
-
- /** \internal */
- using Base::_solve_impl;
- template<typename Rhs,typename Dest>
- void _solve_impl(const MatrixBase<Rhs>& b, Dest& x) const
- {
- x.setZero();
- _solve_with_guess_impl(b.derived(),x);
- }
protected:
diff --git a/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h b/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h
index bfeee71cd..9d08e6d11 100644
--- a/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h
+++ b/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h
@@ -331,7 +331,7 @@ public:
/** \internal */
template<typename Rhs, typename DestDerived>
- void _solve_impl(const Rhs& b, SparseMatrixBase<DestDerived> &aDest) const
+ void _solve_with_guess_impl(const Rhs& b, SparseMatrixBase<DestDerived> &aDest) const
{
eigen_assert(rows()==b.rows());
@@ -344,15 +344,66 @@ public:
// We do not directly fill dest because sparse expressions have to be free of aliasing issue.
// For non square least-square problems, b and dest might not have the same size whereas they might alias each-other.
typename DestDerived::PlainObject tmp(cols(),rhsCols);
+ ComputationInfo global_info = Success;
for(Index k=0; k<rhsCols; ++k)
{
tb = b.col(k);
- tx = derived().solve(tb);
+ tx = dest.col(k);
+ derived()._solve_vector_with_guess_impl(tb,tx);
tmp.col(k) = tx.sparseView(0);
+
+ // The call to _solve_vector_with_guess_impl updates m_info, so if it failed for a previous column
+ // we need to restore it to the worst value.
+ if(m_info==NumericalIssue)
+ global_info = NumericalIssue;
+ else if(m_info==NoConvergence)
+ global_info = NoConvergence;
}
+ m_info = global_info;
dest.swap(tmp);
}
+ template<typename Rhs, typename DestDerived>
+ typename internal::enable_if<Rhs::ColsAtCompileTime!=1 && DestDerived::ColsAtCompileTime!=1>::type
+ _solve_with_guess_impl(const Rhs& b, MatrixBase<DestDerived> &aDest) const
+ {
+ eigen_assert(rows()==b.rows());
+
+ Index rhsCols = b.cols();
+ DestDerived& dest(aDest.derived());
+ ComputationInfo global_info = Success;
+ for(Index k=0; k<rhsCols; ++k)
+ {
+ typename DestDerived::ColXpr xk(dest,k);
+ typename Rhs::ConstColXpr bk(b,k);
+ derived()._solve_vector_with_guess_impl(bk,xk);
+
+ // The call to _solve_vector_with_guess updates m_info, so if it failed for a previous column
+ // we need to restore it to the worst value.
+ if(m_info==NumericalIssue)
+ global_info = NumericalIssue;
+ else if(m_info==NoConvergence)
+ global_info = NoConvergence;
+ }
+ m_info = global_info;
+ }
+
+ template<typename Rhs, typename DestDerived>
+ typename internal::enable_if<Rhs::ColsAtCompileTime==1 || DestDerived::ColsAtCompileTime==1>::type
+ _solve_with_guess_impl(const Rhs& b, MatrixBase<DestDerived> &dest) const
+ {
+ derived()._solve_vector_with_guess_impl(b,dest.derived());
+ }
+
+ /** \internal default initial guess = 0 */
+ template<typename Rhs,typename Dest>
+ void _solve_impl(const Rhs& b, Dest& x) const
+ {
+ x.resize(this->rows(),b.cols());
+ x.setZero();
+ derived()._solve_with_guess_impl(b,x);
+ }
+
protected:
void init()
{
diff --git a/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h b/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h
index 0aea0e099..203fd0ec6 100644
--- a/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h
+++ b/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h
@@ -182,32 +182,14 @@ public:
/** \internal */
template<typename Rhs,typename Dest>
- void _solve_with_guess_impl(const Rhs& b, Dest& x) const
+ void _solve_vector_with_guess_impl(const Rhs& b, Dest& x) const
{
m_iterations = Base::maxIterations();
m_error = Base::m_tolerance;
- for(Index j=0; j<b.cols(); ++j)
- {
- m_iterations = Base::maxIterations();
- m_error = Base::m_tolerance;
-
- typename Dest::ColXpr xj(x,j);
- internal::least_square_conjugate_gradient(matrix(), b.col(j), xj, Base::m_preconditioner, m_iterations, m_error);
- }
-
- m_isInitialized = true;
+ internal::least_square_conjugate_gradient(matrix(), b, x, Base::m_preconditioner, m_iterations, m_error);
m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
}
-
- /** \internal */
- using Base::_solve_impl;
- template<typename Rhs,typename Dest>
- void _solve_impl(const MatrixBase<Rhs>& b, Dest& x) const
- {
- x.setZero();
- _solve_with_guess_impl(b.derived(),x);
- }
};