aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Chen-Pang He <jdh8@ms63.hinet.net>2012-09-30 16:30:18 +0800
committerGravatar Chen-Pang He <jdh8@ms63.hinet.net>2012-09-30 16:30:18 +0800
commiteb33d307af8cda6876b4eb334eaf258fbbfc8bff (patch)
treed89143cea8248ad449496ff0d471834590bca740
parent332eb36436ff4d8f1a6f31749c49f53205e553bd (diff)
Avoid Schur decomposition on (quasi-)triangular matrices. (Huge speed up!)
-rw-r--r--unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h5
-rw-r--r--unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h42
2 files changed, 13 insertions, 34 deletions
diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h b/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h
index 6ec870d3e..166393f00 100644
--- a/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h
+++ b/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h
@@ -129,7 +129,7 @@ void MatrixLogarithmAtomic<MatrixType>::computeBig(const MatrixType& A, MatrixTy
int numberOfSquareRoots = 0;
int numberOfExtraSquareRoots = 0;
int degree;
- MatrixType T = A;
+ MatrixType T = A, sqrtT;
const RealScalar maxNormForPade = maxPadeDegree<= 5? 5.3149729967117310e-1: // single precision
maxPadeDegree<= 7? 2.6429608311114350e-1: // double precision
maxPadeDegree<= 8? 2.32777776523703892094e-1L: // extended precision
@@ -145,9 +145,8 @@ void MatrixLogarithmAtomic<MatrixType>::computeBig(const MatrixType& A, MatrixTy
break;
++numberOfExtraSquareRoots;
}
- MatrixType sqrtT;
MatrixSquareRootTriangular<MatrixType>(T).compute(sqrtT);
- T = sqrtT;
+ T = sqrtT.template triangularView<Upper>();
++numberOfSquareRoots;
}
diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h b/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h
index abbf64096..c1a04036b 100644
--- a/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h
+++ b/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h
@@ -79,18 +79,9 @@ template <typename MatrixType>
template <typename ResultType>
void MatrixSquareRootQuasiTriangular<MatrixType>::compute(ResultType &result)
{
- // Compute Schur decomposition of m_A
- const RealSchur<MatrixType> schurOfA(m_A);
- const MatrixType& T = schurOfA.matrixT();
- const MatrixType& U = schurOfA.matrixU();
-
- // Compute square root of T
- MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
- computeDiagonalPartOfSqrt(sqrtT, T);
- computeOffDiagonalPartOfSqrt(sqrtT, T);
-
- // Compute square root of m_A
- result = U * sqrtT * U.adjoint();
+ result.resize(m_A.rows(), m_A.cols());
+ computeDiagonalPartOfSqrt(result, m_A);
+ computeOffDiagonalPartOfSqrt(result, m_A);
}
// pre: T is quasi-upper-triangular and sqrtT is a zero matrix of the same size
@@ -291,17 +282,13 @@ template <typename ResultType>
void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result)
{
using std::sqrt;
- // Compute Schur decomposition of m_A
- const ComplexSchur<MatrixType> schurOfA(m_A);
- const MatrixType& T = schurOfA.matrixT();
- const MatrixType& U = schurOfA.matrixU();
- // Compute square root of T and store it in upper triangular part of result
+ // Compute square root of m_A and store it in upper triangular part of result
// This uses that the square root of triangular matrices can be computed directly.
result.resize(m_A.rows(), m_A.cols());
typedef typename MatrixType::Index Index;
for (Index i = 0; i < m_A.rows(); i++) {
- result.coeffRef(i,i) = sqrt(T.coeff(i,i));
+ result.coeffRef(i,i) = sqrt(m_A.coeff(i,i));
}
for (Index j = 1; j < m_A.cols(); j++) {
for (Index i = j-1; i >= 0; i--) {
@@ -309,14 +296,9 @@ void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result)
// if i = j-1, then segment has length 0 so tmp = 0
Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value();
// denominator may be zero if original matrix is singular
- result.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j));
+ result.coeffRef(i,j) = (m_A.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j));
}
}
-
- // Compute square root of m_A as U * result * U.adjoint()
- MatrixType tmp;
- tmp.noalias() = U * result.template triangularView<Upper>();
- result.noalias() = tmp * U.adjoint();
}
@@ -373,9 +355,8 @@ class MatrixSquareRoot<MatrixType, 0>
const MatrixType& U = schurOfA.matrixU();
// Compute square root of T
- MatrixSquareRootQuasiTriangular<MatrixType> tmp(T);
- MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
- tmp.compute(sqrtT);
+ MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.cols());
+ MatrixSquareRootQuasiTriangular<MatrixType>(T).compute(sqrtT);
// Compute square root of m_A
result = U * sqrtT * U.adjoint();
@@ -407,12 +388,11 @@ class MatrixSquareRoot<MatrixType, 1>
const MatrixType& U = schurOfA.matrixU();
// Compute square root of T
- MatrixSquareRootTriangular<MatrixType> tmp(T);
- MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
- tmp.compute(sqrtT);
+ MatrixType sqrtT;
+ MatrixSquareRootTriangular<MatrixType>(T).compute(sqrtT);
// Compute square root of m_A
- result = U * sqrtT * U.adjoint();
+ result = U * (sqrtT.template triangularView<Upper>() * U.adjoint());
}
private: