From eb33d307af8cda6876b4eb334eaf258fbbfc8bff Mon Sep 17 00:00:00 2001 From: Chen-Pang He Date: Sun, 30 Sep 2012 16:30:18 +0800 Subject: Avoid Schur decomposition on (quasi-)triangular matrices. (Huge speed up!) --- .../Eigen/src/MatrixFunctions/MatrixLogarithm.h | 5 ++- .../Eigen/src/MatrixFunctions/MatrixSquareRoot.h | 42 ++++++---------------- 2 files changed, 13 insertions(+), 34 deletions(-) (limited to 'unsupported/Eigen/src') diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h b/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h index 6ec870d3e..166393f00 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h @@ -129,7 +129,7 @@ void MatrixLogarithmAtomic::computeBig(const MatrixType& A, MatrixTy int numberOfSquareRoots = 0; int numberOfExtraSquareRoots = 0; int degree; - MatrixType T = A; + MatrixType T = A, sqrtT; const RealScalar maxNormForPade = maxPadeDegree<= 5? 5.3149729967117310e-1: // single precision maxPadeDegree<= 7? 2.6429608311114350e-1: // double precision maxPadeDegree<= 8? 2.32777776523703892094e-1L: // extended precision @@ -145,9 +145,8 @@ void MatrixLogarithmAtomic::computeBig(const MatrixType& A, MatrixTy break; ++numberOfExtraSquareRoots; } - MatrixType sqrtT; MatrixSquareRootTriangular(T).compute(sqrtT); - T = sqrtT; + T = sqrtT.template triangularView(); ++numberOfSquareRoots; } diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h b/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h index abbf64096..c1a04036b 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h @@ -79,18 +79,9 @@ template template void MatrixSquareRootQuasiTriangular::compute(ResultType &result) { - // Compute Schur decomposition of m_A - const RealSchur schurOfA(m_A); - const MatrixType& T = schurOfA.matrixT(); - const MatrixType& U = schurOfA.matrixU(); - - // Compute square root of T - MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows()); - computeDiagonalPartOfSqrt(sqrtT, T); - computeOffDiagonalPartOfSqrt(sqrtT, T); - - // Compute square root of m_A - result = U * sqrtT * U.adjoint(); + result.resize(m_A.rows(), m_A.cols()); + computeDiagonalPartOfSqrt(result, m_A); + computeOffDiagonalPartOfSqrt(result, m_A); } // pre: T is quasi-upper-triangular and sqrtT is a zero matrix of the same size @@ -291,17 +282,13 @@ template void MatrixSquareRootTriangular::compute(ResultType &result) { using std::sqrt; - // Compute Schur decomposition of m_A - const ComplexSchur schurOfA(m_A); - const MatrixType& T = schurOfA.matrixT(); - const MatrixType& U = schurOfA.matrixU(); - // Compute square root of T and store it in upper triangular part of result + // Compute square root of m_A and store it in upper triangular part of result // This uses that the square root of triangular matrices can be computed directly. result.resize(m_A.rows(), m_A.cols()); typedef typename MatrixType::Index Index; for (Index i = 0; i < m_A.rows(); i++) { - result.coeffRef(i,i) = sqrt(T.coeff(i,i)); + result.coeffRef(i,i) = sqrt(m_A.coeff(i,i)); } for (Index j = 1; j < m_A.cols(); j++) { for (Index i = j-1; i >= 0; i--) { @@ -309,14 +296,9 @@ void MatrixSquareRootTriangular::compute(ResultType &result) // if i = j-1, then segment has length 0 so tmp = 0 Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value(); // denominator may be zero if original matrix is singular - result.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j)); + result.coeffRef(i,j) = (m_A.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j)); } } - - // Compute square root of m_A as U * result * U.adjoint() - MatrixType tmp; - tmp.noalias() = U * result.template triangularView(); - result.noalias() = tmp * U.adjoint(); } @@ -373,9 +355,8 @@ class MatrixSquareRoot const MatrixType& U = schurOfA.matrixU(); // Compute square root of T - MatrixSquareRootQuasiTriangular tmp(T); - MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows()); - tmp.compute(sqrtT); + MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.cols()); + MatrixSquareRootQuasiTriangular(T).compute(sqrtT); // Compute square root of m_A result = U * sqrtT * U.adjoint(); @@ -407,12 +388,11 @@ class MatrixSquareRoot const MatrixType& U = schurOfA.matrixU(); // Compute square root of T - MatrixSquareRootTriangular tmp(T); - MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows()); - tmp.compute(sqrtT); + MatrixType sqrtT; + MatrixSquareRootTriangular(T).compute(sqrtT); // Compute square root of m_A - result = U * sqrtT * U.adjoint(); + result = U * (sqrtT.template triangularView() * U.adjoint()); } private: -- cgit v1.2.3