aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2013-07-01 11:49:23 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2013-07-01 11:49:23 +0200
commit22820e950e916917cebfc8aa1633162b847b53f7 (patch)
tree65fbd5ee901334c8a1ff10739de3cf69b9637abf /Eigen
parent99bef0957b218a493d8fc1649cf052677893f201 (diff)
Improve BiCGSTAB robustness: fix a divide by zero and allow to restart with a new initial residual reference.
Diffstat (limited to 'Eigen')
-rw-r--r--Eigen/src/IterativeLinearSolvers/BiCGSTAB.h24
1 files changed, 18 insertions, 6 deletions
diff --git a/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h b/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h
index fbefb696f..048d59170 100644
--- a/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h
+++ b/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h
@@ -43,8 +43,9 @@ bool bicgstab(const MatrixType& mat, const Rhs& rhs, Dest& x,
VectorType r = rhs - mat * x;
VectorType r0 = r;
- RealScalar r0_sqnorm = rhs.squaredNorm();
- if(r0_sqnorm == 0)
+ RealScalar r0_sqnorm = r0.squaredNorm();
+ RealScalar rhs_sqnorm = rhs.squaredNorm();
+ if(rhs_sqnorm == 0)
{
x.setZero();
return true;
@@ -61,13 +62,22 @@ bool bicgstab(const MatrixType& mat, const Rhs& rhs, Dest& x,
RealScalar tol2 = tol*tol;
int i = 0;
+ int restarts = 0;
- while ( r.squaredNorm()/r0_sqnorm > tol2 && i<maxIters )
+ while ( r.squaredNorm()/rhs_sqnorm > tol2 && i<maxIters )
{
Scalar rho_old = rho;
rho = r0.dot(r);
- if (rho == Scalar(0)) return false; /* New search directions cannot be found */
+ if (internal::isMuchSmallerThan(rho,r0_sqnorm))
+ {
+ // The new residual vector became too orthogonal to the arbitrarily choosen direction r0
+ // Let's restart with a new r0:
+ r0 = r;
+ r0_sqnorm = rho = r0.dot(r);
+ if(restarts++ == 0)
+ i = 0;
+ }
Scalar beta = (rho/rho_old) * (alpha / w);
p = r + beta * (p - w * v);
@@ -81,12 +91,14 @@ bool bicgstab(const MatrixType& mat, const Rhs& rhs, Dest& x,
z = precond.solve(s);
t.noalias() = mat * z;
- w = t.dot(s) / t.squaredNorm();
+ w = t.squaredNorm();
+ if(w>RealScalar(0))
+ w = t.dot(s) / t.squaredNorm();
x += alpha * y + w * z;
r = s - w * t;
++i;
}
- tol_error = sqrt(r.squaredNorm()/r0_sqnorm);
+ tol_error = sqrt(r.squaredNorm()/rhs_sqnorm);
iters = i;
return true;
}