diff options
author | Chen-Pang He <jdh8@ms63.hinet.net> | 2012-09-30 16:30:18 +0800 |
---|---|---|
committer | Chen-Pang He <jdh8@ms63.hinet.net> | 2012-09-30 16:30:18 +0800 |
commit | eb33d307af8cda6876b4eb334eaf258fbbfc8bff (patch) | |
tree | d89143cea8248ad449496ff0d471834590bca740 | |
parent | 332eb36436ff4d8f1a6f31749c49f53205e553bd (diff) |
Avoid Schur decomposition on (quasi-)triangular matrices. (Huge speed up!)
-rw-r--r-- | unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h | 5 | ||||
-rw-r--r-- | unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h | 42 |
2 files changed, 13 insertions, 34 deletions
diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h b/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h index 6ec870d3e..166393f00 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h @@ -129,7 +129,7 @@ void MatrixLogarithmAtomic<MatrixType>::computeBig(const MatrixType& A, MatrixTy int numberOfSquareRoots = 0; int numberOfExtraSquareRoots = 0; int degree; - MatrixType T = A; + MatrixType T = A, sqrtT; const RealScalar maxNormForPade = maxPadeDegree<= 5? 5.3149729967117310e-1: // single precision maxPadeDegree<= 7? 2.6429608311114350e-1: // double precision maxPadeDegree<= 8? 2.32777776523703892094e-1L: // extended precision @@ -145,9 +145,8 @@ void MatrixLogarithmAtomic<MatrixType>::computeBig(const MatrixType& A, MatrixTy break; ++numberOfExtraSquareRoots; } - MatrixType sqrtT; MatrixSquareRootTriangular<MatrixType>(T).compute(sqrtT); - T = sqrtT; + T = sqrtT.template triangularView<Upper>(); ++numberOfSquareRoots; } diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h b/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h index abbf64096..c1a04036b 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h @@ -79,18 +79,9 @@ template <typename MatrixType> template <typename ResultType> void MatrixSquareRootQuasiTriangular<MatrixType>::compute(ResultType &result) { - // Compute Schur decomposition of m_A - const RealSchur<MatrixType> schurOfA(m_A); - const MatrixType& T = schurOfA.matrixT(); - const MatrixType& U = schurOfA.matrixU(); - - // Compute square root of T - MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows()); - computeDiagonalPartOfSqrt(sqrtT, T); - computeOffDiagonalPartOfSqrt(sqrtT, T); - - // Compute square root of m_A - result = U * sqrtT * U.adjoint(); + result.resize(m_A.rows(), m_A.cols()); + computeDiagonalPartOfSqrt(result, m_A); + computeOffDiagonalPartOfSqrt(result, m_A); } // pre: T is quasi-upper-triangular and sqrtT is a zero matrix of the same size @@ -291,17 +282,13 @@ template <typename ResultType> void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result) { using std::sqrt; - // Compute Schur decomposition of m_A - const ComplexSchur<MatrixType> schurOfA(m_A); - const MatrixType& T = schurOfA.matrixT(); - const MatrixType& U = schurOfA.matrixU(); - // Compute square root of T and store it in upper triangular part of result + // Compute square root of m_A and store it in upper triangular part of result // This uses that the square root of triangular matrices can be computed directly. result.resize(m_A.rows(), m_A.cols()); typedef typename MatrixType::Index Index; for (Index i = 0; i < m_A.rows(); i++) { - result.coeffRef(i,i) = sqrt(T.coeff(i,i)); + result.coeffRef(i,i) = sqrt(m_A.coeff(i,i)); } for (Index j = 1; j < m_A.cols(); j++) { for (Index i = j-1; i >= 0; i--) { @@ -309,14 +296,9 @@ void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result) // if i = j-1, then segment has length 0 so tmp = 0 Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value(); // denominator may be zero if original matrix is singular - result.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j)); + result.coeffRef(i,j) = (m_A.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j)); } } - - // Compute square root of m_A as U * result * U.adjoint() - MatrixType tmp; - tmp.noalias() = U * result.template triangularView<Upper>(); - result.noalias() = tmp * U.adjoint(); } @@ -373,9 +355,8 @@ class MatrixSquareRoot<MatrixType, 0> const MatrixType& U = schurOfA.matrixU(); // Compute square root of T - MatrixSquareRootQuasiTriangular<MatrixType> tmp(T); - MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows()); - tmp.compute(sqrtT); + MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.cols()); + MatrixSquareRootQuasiTriangular<MatrixType>(T).compute(sqrtT); // Compute square root of m_A result = U * sqrtT * U.adjoint(); @@ -407,12 +388,11 @@ class MatrixSquareRoot<MatrixType, 1> const MatrixType& U = schurOfA.matrixU(); // Compute square root of T - MatrixSquareRootTriangular<MatrixType> tmp(T); - MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows()); - tmp.compute(sqrtT); + MatrixType sqrtT; + MatrixSquareRootTriangular<MatrixType>(T).compute(sqrtT); // Compute square root of m_A - result = U * sqrtT * U.adjoint(); + result = U * (sqrtT.template triangularView<Upper>() * U.adjoint()); } private: |