aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2011-10-24 09:33:24 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2011-10-24 09:33:24 +0200
commit5d43b4049dd7eb4d5e742a4441ee164bb886e6fe (patch)
tree2f44c218591ced85b41403e6e3bf8b2917dede80 /unsupported
parent70df09b76d1a13a55de1ebe6834ee359f403be89 (diff)
factorize solving with guess
Diffstat (limited to 'unsupported')
-rw-r--r--unsupported/Eigen/src/IterativeSolvers/BiCGSTAB.h40
-rw-r--r--unsupported/Eigen/src/IterativeSolvers/ConjugateGradient.h54
-rw-r--r--unsupported/Eigen/src/SparseExtra/Solve.h42
3 files changed, 72 insertions, 64 deletions
diff --git a/unsupported/Eigen/src/IterativeSolvers/BiCGSTAB.h b/unsupported/Eigen/src/IterativeSolvers/BiCGSTAB.h
index ee2ab128d..798f85da5 100644
--- a/unsupported/Eigen/src/IterativeSolvers/BiCGSTAB.h
+++ b/unsupported/Eigen/src/IterativeSolvers/BiCGSTAB.h
@@ -106,9 +106,6 @@ class BiCGSTAB;
namespace internal {
-template<typename CG, typename Rhs, typename Guess>
-class bicgstab_solve_retval_with_guess;
-
template< typename _MatrixType, typename _Preconditioner>
struct traits<BiCGSTAB<_MatrixType,_Preconditioner> >
{
@@ -204,19 +201,19 @@ public:
* \sa compute()
*/
template<typename Rhs,typename Guess>
- inline const internal::bicgstab_solve_retval_with_guess<BiCGSTAB, Rhs, Guess>
+ inline const internal::solve_retval_with_guess<BiCGSTAB, Rhs, Guess>
solveWithGuess(const MatrixBase<Rhs>& b, const Guess& x0) const
{
eigen_assert(m_isInitialized && "BiCGSTAB is not initialized.");
eigen_assert(Base::rows()==b.rows()
&& "BiCGSTAB::solve(): invalid number of rows of the right hand side matrix b");
- return internal::bicgstab_solve_retval_with_guess
+ return internal::solve_retval_with_guess
<BiCGSTAB, Rhs, Guess>(*this, b.derived(), x0);
}
/** \internal */
template<typename Rhs,typename Dest>
- void _solve(const Rhs& b, Dest& x) const
+ void _solveWithGuess(const Rhs& b, Dest& x) const
{
for(int j=0; j<b.cols(); ++j)
{
@@ -231,6 +228,14 @@ public:
m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
}
+ /** \internal */
+ template<typename Rhs,typename Dest>
+ void _solve(const Rhs& b, Dest& x) const
+ {
+ x.setOnes();
+ _solveWithGuess(b,x);
+ }
+
protected:
};
@@ -247,33 +252,10 @@ struct solve_retval<BiCGSTAB<_MatrixType, _Preconditioner>, Rhs>
template<typename Dest> void evalTo(Dest& dst) const
{
- dst.setOnes();
dec()._solve(rhs(),dst);
}
};
-template<typename CG, typename Rhs, typename Guess>
-class bicgstab_solve_retval_with_guess
- : public solve_retval_base<CG, Rhs>
-{
- typedef Eigen::internal::solve_retval_base<CG,Rhs> Base;
- using Base::dec;
- using Base::rhs;
- public:
- bicgstab_solve_retval_with_guess(const CG& cg, const Rhs& rhs, const Guess& guess)
- : Base(cg, rhs), m_guess(guess)
- {}
-
- template<typename Dest> void evalTo(Dest& dst) const
- {
- dst = m_guess;
- dec()._solve(rhs(), dst);
- }
- protected:
- const Guess& m_guess;
-
-};
-
}
#endif // EIGEN_BICGSTAB_H
diff --git a/unsupported/Eigen/src/IterativeSolvers/ConjugateGradient.h b/unsupported/Eigen/src/IterativeSolvers/ConjugateGradient.h
index 2a78337c5..ced3e310c 100644
--- a/unsupported/Eigen/src/IterativeSolvers/ConjugateGradient.h
+++ b/unsupported/Eigen/src/IterativeSolvers/ConjugateGradient.h
@@ -37,6 +37,7 @@ namespace internal {
* \param tol_error On input the tolerance error, on output an estimation of the relative error.
*/
template<typename MatrixType, typename Rhs, typename Dest, typename Preconditioner>
+EIGEN_DONT_INLINE
void conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x,
const Preconditioner& precond, int& iters,
typename Dest::RealScalar& tol_error)
@@ -59,7 +60,7 @@ void conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x,
VectorType z(n), tmp(n);
RealScalar absNew = internal::real(residual.dot(p)); // the square of the absolute value of r scaled by invM
RealScalar absInit = absNew; // the initial absolute value
-
+
int i = 0;
while ((i < maxIters) && (absNew > tol*tol*absInit))
{
@@ -89,9 +90,6 @@ class ConjugateGradient;
namespace internal {
-template<typename CG, typename Rhs, typename Guess>
-class conjugate_gradient_solve_retval_with_guess;
-
template< typename _MatrixType, int _UpLo, typename _Preconditioner>
struct traits<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner> >
{
@@ -193,33 +191,44 @@ public:
* \sa compute()
*/
template<typename Rhs,typename Guess>
- inline const internal::conjugate_gradient_solve_retval_with_guess<ConjugateGradient, Rhs, Guess>
+ inline const internal::solve_retval_with_guess<ConjugateGradient, Rhs, Guess>
solveWithGuess(const MatrixBase<Rhs>& b, const Guess& x0) const
{
eigen_assert(m_isInitialized && "ConjugateGradient is not initialized.");
eigen_assert(Base::rows()==b.rows()
&& "ConjugateGradient::solve(): invalid number of rows of the right hand side matrix b");
- return internal::conjugate_gradient_solve_retval_with_guess
+ return internal::solve_retval_with_guess
<ConjugateGradient, Rhs, Guess>(*this, b.derived(), x0);
}
-
+
/** \internal */
template<typename Rhs,typename Dest>
- void _solve(const Rhs& b, Dest& x) const
+ void _solveWithGuess(const Rhs& b, Dest& x) const
{
+ m_iterations = Base::m_maxIterations;
+ m_error = Base::m_tolerance;
+
for(int j=0; j<b.cols(); ++j)
{
m_iterations = Base::m_maxIterations;
m_error = Base::m_tolerance;
-
+
typename Dest::ColXpr xj(x,j);
internal::conjugate_gradient(mp_matrix->template selfadjointView<UpLo>(), b.col(j), xj,
Base::m_preconditioner, m_iterations, m_error);
}
-
+
m_isInitialized = true;
m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
}
+
+ /** \internal */
+ template<typename Rhs,typename Dest>
+ void _solve(const Rhs& b, Dest& x) const
+ {
+ x.setOnes();
+ _solveWithGuess(b,x);
+ }
protected:
@@ -228,7 +237,7 @@ protected:
namespace internal {
- template<typename _MatrixType, int _UpLo, typename _Preconditioner, typename Rhs>
+template<typename _MatrixType, int _UpLo, typename _Preconditioner, typename Rhs>
struct solve_retval<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner>, Rhs>
: solve_retval_base<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner>, Rhs>
{
@@ -237,33 +246,10 @@ struct solve_retval<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner>, Rhs>
template<typename Dest> void evalTo(Dest& dst) const
{
- dst.setOnes();
dec()._solve(rhs(),dst);
}
};
-template<typename CG, typename Rhs, typename Guess>
-class conjugate_gradient_solve_retval_with_guess
- : public solve_retval_base<CG, Rhs>
-{
- typedef Eigen::internal::solve_retval_base<CG,Rhs> Base;
- using Base::dec;
- using Base::rhs;
- public:
- conjugate_gradient_solve_retval_with_guess(const CG& cg, const Rhs& rhs, const Guess& guess)
- : Base(cg, rhs), m_guess(guess)
- {}
-
- template<typename Dest> void evalTo(Dest& dst) const
- {
- dst = m_guess;
- dec()._solve(rhs(), dst);
- }
- protected:
- const Guess& m_guess;
-
-};
-
}
#endif // EIGEN_CONJUGATE_GRADIENT_H
diff --git a/unsupported/Eigen/src/SparseExtra/Solve.h b/unsupported/Eigen/src/SparseExtra/Solve.h
index 19449e9de..5b6c859ae 100644
--- a/unsupported/Eigen/src/SparseExtra/Solve.h
+++ b/unsupported/Eigen/src/SparseExtra/Solve.h
@@ -76,7 +76,47 @@ template<typename _DecompositionType, typename Rhs> struct sparse_solve_retval_b
using Base::cols; \
sparse_solve_retval(const DecompositionType& dec, const Rhs& rhs) \
: Base(dec, rhs) {}
-
+
+
+
+template<typename DecompositionType, typename Rhs, typename Guess> struct solve_retval_with_guess;
+
+template<typename DecompositionType, typename Rhs, typename Guess>
+struct traits<solve_retval_with_guess<DecompositionType, Rhs, Guess> >
+{
+ typedef typename DecompositionType::MatrixType MatrixType;
+ typedef Matrix<typename Rhs::Scalar,
+ MatrixType::ColsAtCompileTime,
+ Rhs::ColsAtCompileTime,
+ Rhs::PlainObject::Options,
+ MatrixType::MaxColsAtCompileTime,
+ Rhs::MaxColsAtCompileTime> ReturnType;
+};
+
+template<typename DecompositionType, typename Rhs, typename Guess> struct solve_retval_with_guess
+ : public ReturnByValue<solve_retval_with_guess<DecompositionType, Rhs, Guess> >
+{
+ typedef typename DecompositionType::Index Index;
+
+ solve_retval_with_guess(const DecompositionType& dec, const Rhs& rhs, const Guess& guess)
+ : m_dec(dec), m_rhs(rhs), m_guess(guess)
+ {}
+
+ inline Index rows() const { return m_dec.cols(); }
+ inline Index cols() const { return m_rhs.cols(); }
+
+ template<typename Dest> inline void evalTo(Dest& dst) const
+ {
+ dst = m_guess;
+ m_dec._solveWithGuess(m_rhs,dst);
+ }
+
+ protected:
+ const DecompositionType& m_dec;
+ const typename Rhs::Nested m_rhs;
+ const typename Guess::Nested m_guess;
+};
+
} // namepsace internal
#endif // EIGEN_SPARSE_SOLVE_H