diff options
author | 2011-10-24 09:33:24 +0200 | |
---|---|---|
committer | 2011-10-24 09:33:24 +0200 | |
commit | 5d43b4049dd7eb4d5e742a4441ee164bb886e6fe (patch) | |
tree | 2f44c218591ced85b41403e6e3bf8b2917dede80 /unsupported | |
parent | 70df09b76d1a13a55de1ebe6834ee359f403be89 (diff) |
factorize solving with guess
Diffstat (limited to 'unsupported')
-rw-r--r-- | unsupported/Eigen/src/IterativeSolvers/BiCGSTAB.h | 40 | ||||
-rw-r--r-- | unsupported/Eigen/src/IterativeSolvers/ConjugateGradient.h | 54 | ||||
-rw-r--r-- | unsupported/Eigen/src/SparseExtra/Solve.h | 42 |
3 files changed, 72 insertions, 64 deletions
diff --git a/unsupported/Eigen/src/IterativeSolvers/BiCGSTAB.h b/unsupported/Eigen/src/IterativeSolvers/BiCGSTAB.h index ee2ab128d..798f85da5 100644 --- a/unsupported/Eigen/src/IterativeSolvers/BiCGSTAB.h +++ b/unsupported/Eigen/src/IterativeSolvers/BiCGSTAB.h @@ -106,9 +106,6 @@ class BiCGSTAB; namespace internal { -template<typename CG, typename Rhs, typename Guess> -class bicgstab_solve_retval_with_guess; - template< typename _MatrixType, typename _Preconditioner> struct traits<BiCGSTAB<_MatrixType,_Preconditioner> > { @@ -204,19 +201,19 @@ public: * \sa compute() */ template<typename Rhs,typename Guess> - inline const internal::bicgstab_solve_retval_with_guess<BiCGSTAB, Rhs, Guess> + inline const internal::solve_retval_with_guess<BiCGSTAB, Rhs, Guess> solveWithGuess(const MatrixBase<Rhs>& b, const Guess& x0) const { eigen_assert(m_isInitialized && "BiCGSTAB is not initialized."); eigen_assert(Base::rows()==b.rows() && "BiCGSTAB::solve(): invalid number of rows of the right hand side matrix b"); - return internal::bicgstab_solve_retval_with_guess + return internal::solve_retval_with_guess <BiCGSTAB, Rhs, Guess>(*this, b.derived(), x0); } /** \internal */ template<typename Rhs,typename Dest> - void _solve(const Rhs& b, Dest& x) const + void _solveWithGuess(const Rhs& b, Dest& x) const { for(int j=0; j<b.cols(); ++j) { @@ -231,6 +228,14 @@ public: m_info = m_error <= Base::m_tolerance ? Success : NoConvergence; } + /** \internal */ + template<typename Rhs,typename Dest> + void _solve(const Rhs& b, Dest& x) const + { + x.setOnes(); + _solveWithGuess(b,x); + } + protected: }; @@ -247,33 +252,10 @@ struct solve_retval<BiCGSTAB<_MatrixType, _Preconditioner>, Rhs> template<typename Dest> void evalTo(Dest& dst) const { - dst.setOnes(); dec()._solve(rhs(),dst); } }; -template<typename CG, typename Rhs, typename Guess> -class bicgstab_solve_retval_with_guess - : public solve_retval_base<CG, Rhs> -{ - typedef Eigen::internal::solve_retval_base<CG,Rhs> Base; - using Base::dec; - using Base::rhs; - public: - bicgstab_solve_retval_with_guess(const CG& cg, const Rhs& rhs, const Guess& guess) - : Base(cg, rhs), m_guess(guess) - {} - - template<typename Dest> void evalTo(Dest& dst) const - { - dst = m_guess; - dec()._solve(rhs(), dst); - } - protected: - const Guess& m_guess; - -}; - } #endif // EIGEN_BICGSTAB_H diff --git a/unsupported/Eigen/src/IterativeSolvers/ConjugateGradient.h b/unsupported/Eigen/src/IterativeSolvers/ConjugateGradient.h index 2a78337c5..ced3e310c 100644 --- a/unsupported/Eigen/src/IterativeSolvers/ConjugateGradient.h +++ b/unsupported/Eigen/src/IterativeSolvers/ConjugateGradient.h @@ -37,6 +37,7 @@ namespace internal { * \param tol_error On input the tolerance error, on output an estimation of the relative error. */ template<typename MatrixType, typename Rhs, typename Dest, typename Preconditioner> +EIGEN_DONT_INLINE void conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x, const Preconditioner& precond, int& iters, typename Dest::RealScalar& tol_error) @@ -59,7 +60,7 @@ void conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x, VectorType z(n), tmp(n); RealScalar absNew = internal::real(residual.dot(p)); // the square of the absolute value of r scaled by invM RealScalar absInit = absNew; // the initial absolute value - + int i = 0; while ((i < maxIters) && (absNew > tol*tol*absInit)) { @@ -89,9 +90,6 @@ class ConjugateGradient; namespace internal { -template<typename CG, typename Rhs, typename Guess> -class conjugate_gradient_solve_retval_with_guess; - template< typename _MatrixType, int _UpLo, typename _Preconditioner> struct traits<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner> > { @@ -193,33 +191,44 @@ public: * \sa compute() */ template<typename Rhs,typename Guess> - inline const internal::conjugate_gradient_solve_retval_with_guess<ConjugateGradient, Rhs, Guess> + inline const internal::solve_retval_with_guess<ConjugateGradient, Rhs, Guess> solveWithGuess(const MatrixBase<Rhs>& b, const Guess& x0) const { eigen_assert(m_isInitialized && "ConjugateGradient is not initialized."); eigen_assert(Base::rows()==b.rows() && "ConjugateGradient::solve(): invalid number of rows of the right hand side matrix b"); - return internal::conjugate_gradient_solve_retval_with_guess + return internal::solve_retval_with_guess <ConjugateGradient, Rhs, Guess>(*this, b.derived(), x0); } - + /** \internal */ template<typename Rhs,typename Dest> - void _solve(const Rhs& b, Dest& x) const + void _solveWithGuess(const Rhs& b, Dest& x) const { + m_iterations = Base::m_maxIterations; + m_error = Base::m_tolerance; + for(int j=0; j<b.cols(); ++j) { m_iterations = Base::m_maxIterations; m_error = Base::m_tolerance; - + typename Dest::ColXpr xj(x,j); internal::conjugate_gradient(mp_matrix->template selfadjointView<UpLo>(), b.col(j), xj, Base::m_preconditioner, m_iterations, m_error); } - + m_isInitialized = true; m_info = m_error <= Base::m_tolerance ? Success : NoConvergence; } + + /** \internal */ + template<typename Rhs,typename Dest> + void _solve(const Rhs& b, Dest& x) const + { + x.setOnes(); + _solveWithGuess(b,x); + } protected: @@ -228,7 +237,7 @@ protected: namespace internal { - template<typename _MatrixType, int _UpLo, typename _Preconditioner, typename Rhs> +template<typename _MatrixType, int _UpLo, typename _Preconditioner, typename Rhs> struct solve_retval<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner>, Rhs> : solve_retval_base<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner>, Rhs> { @@ -237,33 +246,10 @@ struct solve_retval<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner>, Rhs> template<typename Dest> void evalTo(Dest& dst) const { - dst.setOnes(); dec()._solve(rhs(),dst); } }; -template<typename CG, typename Rhs, typename Guess> -class conjugate_gradient_solve_retval_with_guess - : public solve_retval_base<CG, Rhs> -{ - typedef Eigen::internal::solve_retval_base<CG,Rhs> Base; - using Base::dec; - using Base::rhs; - public: - conjugate_gradient_solve_retval_with_guess(const CG& cg, const Rhs& rhs, const Guess& guess) - : Base(cg, rhs), m_guess(guess) - {} - - template<typename Dest> void evalTo(Dest& dst) const - { - dst = m_guess; - dec()._solve(rhs(), dst); - } - protected: - const Guess& m_guess; - -}; - } #endif // EIGEN_CONJUGATE_GRADIENT_H diff --git a/unsupported/Eigen/src/SparseExtra/Solve.h b/unsupported/Eigen/src/SparseExtra/Solve.h index 19449e9de..5b6c859ae 100644 --- a/unsupported/Eigen/src/SparseExtra/Solve.h +++ b/unsupported/Eigen/src/SparseExtra/Solve.h @@ -76,7 +76,47 @@ template<typename _DecompositionType, typename Rhs> struct sparse_solve_retval_b using Base::cols; \ sparse_solve_retval(const DecompositionType& dec, const Rhs& rhs) \ : Base(dec, rhs) {} - + + + +template<typename DecompositionType, typename Rhs, typename Guess> struct solve_retval_with_guess; + +template<typename DecompositionType, typename Rhs, typename Guess> +struct traits<solve_retval_with_guess<DecompositionType, Rhs, Guess> > +{ + typedef typename DecompositionType::MatrixType MatrixType; + typedef Matrix<typename Rhs::Scalar, + MatrixType::ColsAtCompileTime, + Rhs::ColsAtCompileTime, + Rhs::PlainObject::Options, + MatrixType::MaxColsAtCompileTime, + Rhs::MaxColsAtCompileTime> ReturnType; +}; + +template<typename DecompositionType, typename Rhs, typename Guess> struct solve_retval_with_guess + : public ReturnByValue<solve_retval_with_guess<DecompositionType, Rhs, Guess> > +{ + typedef typename DecompositionType::Index Index; + + solve_retval_with_guess(const DecompositionType& dec, const Rhs& rhs, const Guess& guess) + : m_dec(dec), m_rhs(rhs), m_guess(guess) + {} + + inline Index rows() const { return m_dec.cols(); } + inline Index cols() const { return m_rhs.cols(); } + + template<typename Dest> inline void evalTo(Dest& dst) const + { + dst = m_guess; + m_dec._solveWithGuess(m_rhs,dst); + } + + protected: + const DecompositionType& m_dec; + const typename Rhs::Nested m_rhs; + const typename Guess::Nested m_guess; +}; + } // namepsace internal #endif // EIGEN_SPARSE_SOLVE_H |