aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2011-10-11 11:29:50 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2011-10-11 11:29:50 +0200
commit5dc845829365a82caf8c7ce487f50a1a7bcd5312 (patch)
tree34f2c1cf801467af5472ce33849e56be55b9549b
parentb94c00226f52f48f1d2e6a5060325d4327ba27ff (diff)
extend CG for multiple right hand sides
-rw-r--r--unsupported/Eigen/IterativeSolvers3
-rw-r--r--unsupported/Eigen/src/IterativeSolvers/ConjugateGradient.h18
-rw-r--r--unsupported/Eigen/src/IterativeSolvers/IterativeSolverBase.h48
3 files changed, 61 insertions, 8 deletions
diff --git a/unsupported/Eigen/IterativeSolvers b/unsupported/Eigen/IterativeSolvers
index 2a06ded1a..db4940bf3 100644
--- a/unsupported/Eigen/IterativeSolvers
+++ b/unsupported/Eigen/IterativeSolvers
@@ -25,7 +25,7 @@
#ifndef EIGEN_ITERATIVE_SOLVERS_MODULE_H
#define EIGEN_ITERATIVE_SOLVERS_MODULE_H
-#include <Eigen/Core>
+#include <Eigen/Sparse>
namespace Eigen {
@@ -42,6 +42,7 @@ namespace Eigen {
//@{
#include "../../Eigen/src/misc/Solve.h"
+#include "src/SparseExtra/Solve.h"
#include "src/IterativeSolvers/IterativeSolverBase.h"
#include "src/IterativeSolvers/IterationController.h"
diff --git a/unsupported/Eigen/src/IterativeSolvers/ConjugateGradient.h b/unsupported/Eigen/src/IterativeSolvers/ConjugateGradient.h
index 0b0b4955b..2a78337c5 100644
--- a/unsupported/Eigen/src/IterativeSolvers/ConjugateGradient.h
+++ b/unsupported/Eigen/src/IterativeSolvers/ConjugateGradient.h
@@ -45,7 +45,7 @@ void conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x,
using std::abs;
typedef typename Dest::RealScalar RealScalar;
typedef typename Dest::Scalar Scalar;
- typedef Dest VectorType;
+ typedef Matrix<Scalar,Dynamic,1> VectorType;
RealScalar tol = tol_error;
int maxIters = iters;
@@ -207,11 +207,15 @@ public:
template<typename Rhs,typename Dest>
void _solve(const Rhs& b, Dest& x) const
{
- m_iterations = Base::m_maxIterations;
- m_error = Base::m_tolerance;
-
- internal::conjugate_gradient(mp_matrix->template selfadjointView<UpLo>(), b, x,
- Base::m_preconditioner, m_iterations, m_error);
+ for(int j=0; j<b.cols(); ++j)
+ {
+ m_iterations = Base::m_maxIterations;
+ m_error = Base::m_tolerance;
+
+ typename Dest::ColXpr xj(x,j);
+ internal::conjugate_gradient(mp_matrix->template selfadjointView<UpLo>(), b.col(j), xj,
+ Base::m_preconditioner, m_iterations, m_error);
+ }
m_isInitialized = true;
m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
@@ -233,7 +237,7 @@ struct solve_retval<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner>, Rhs>
template<typename Dest> void evalTo(Dest& dst) const
{
- dst.setZero();
+ dst.setOnes();
dec()._solve(rhs(),dst);
}
};
diff --git a/unsupported/Eigen/src/IterativeSolvers/IterativeSolverBase.h b/unsupported/Eigen/src/IterativeSolvers/IterativeSolverBase.h
index b79b6fa22..5e8bbd3e2 100644
--- a/unsupported/Eigen/src/IterativeSolvers/IterativeSolverBase.h
+++ b/unsupported/Eigen/src/IterativeSolvers/IterativeSolverBase.h
@@ -146,6 +146,20 @@ public:
&& "IterativeSolverBase::solve(): invalid number of rows of the right hand side matrix b");
return internal::solve_retval<Derived, Rhs>(derived(), b.derived());
}
+
+ /** \returns the solution x of \f$ A x = b \f$ using the current decomposition of A.
+ *
+ * \sa compute()
+ */
+ template<typename Rhs>
+ inline const internal::sparse_solve_retval<IterativeSolverBase, Rhs>
+ solve(const SparseMatrixBase<Rhs>& b) const
+ {
+ eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
+ eigen_assert(rows()==b.rows()
+ && "IterativeSolverBase::solve(): invalid number of rows of the right hand side matrix b");
+ return internal::sparse_solve_retval<IterativeSolverBase, Rhs>(*this, b.derived());
+ }
/** \returns Success if the iterations converged, and NoConvergence otherwise. */
ComputationInfo info() const
@@ -153,6 +167,24 @@ public:
eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
return m_info;
}
+
+ /** \internal */
+ template<typename Rhs, typename DestScalar, int DestOptions, typename DestIndex>
+ void _solve_sparse(const Rhs& b, SparseMatrix<DestScalar,DestOptions,DestIndex> &dest) const
+ {
+ eigen_assert(rows()==b.rows());
+
+ int rhsCols = b.cols();
+ int size = b.rows();
+ Eigen::Matrix<DestScalar,Dynamic,1> tb(size);
+ Eigen::Matrix<DestScalar,Dynamic,1> tx(size);
+ for(int k=0; k<rhsCols; ++k)
+ {
+ tb = b.col(k);
+ tx = derived().solve(tb);
+ dest.col(k) = tx.sparseView(0);
+ }
+ }
protected:
void init()
@@ -173,5 +205,21 @@ protected:
mutable bool m_isInitialized;
};
+namespace internal {
+
+template<typename Derived, typename Rhs>
+struct sparse_solve_retval<IterativeSolverBase<Derived>, Rhs>
+ : sparse_solve_retval_base<IterativeSolverBase<Derived>, Rhs>
+{
+ typedef IterativeSolverBase<Derived> Dec;
+ EIGEN_MAKE_SPARSE_SOLVE_HELPERS(Dec,Rhs)
+
+ template<typename Dest> void evalTo(Dest& dst) const
+ {
+ dec().derived()._solve_sparse(rhs(),dst);
+ }
+};
+
+}
#endif // EIGEN_ITERATIVE_SOLVER_BASE_H