diff options
author | Jitse Niesen <jitse@maths.leeds.ac.uk> | 2011-08-25 07:42:21 +0100 |
---|---|---|
committer | Jitse Niesen <jitse@maths.leeds.ac.uk> | 2011-08-25 07:42:21 +0100 |
commit | c01ed935dd6fd536134c26c2eda7511269c28a6f (patch) | |
tree | 45216d5407ea0dd062f7e5fc9d68296fe8a92832 /unsupported | |
parent | 8ddd1e390b43b9d86c02fbdacdf7b70322eecfff (diff) |
Split code for (quasi)triangular matrices from MatrixSquareRoot.
This way, (quasi)triangular matrices can avoid the costly Schur decomposition.
Diffstat (limited to 'unsupported')
-rw-r--r-- | unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h | 269 |
1 files changed, 190 insertions, 79 deletions
diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h b/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h index b56fcf06b..ed1b5ee35 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h @@ -26,76 +26,71 @@ #define EIGEN_MATRIX_SQUARE_ROOT /** \ingroup MatrixFunctions_Module - * \brief Class for computing matrix square roots. - * \tparam MatrixType type of the argument of the matrix square root, - * expected to be an instantiation of the Matrix class template. + * \brief Class for computing matrix square roots of upper quasi-triangular matrices. + * \tparam MatrixType type of the argument of the matrix square root, + * expected to be an instantiation of the Matrix class template. + * + * This class computes the square root of the upper quasi-triangular + * matrix stored in the upper Hessenberg part of the matrix passed to + * the constructor. + * + * \sa MatrixSquareRoot, MatrixSquareRootTriangular */ -template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex> -class MatrixSquareRoot -{ +template <typename MatrixType> +class MatrixSquareRootQuasiTriangular +{ public: /** \brief Constructor. * - * \param[in] A matrix whose square root is to be computed. + * \param[in] A upper quasi-triangular matrix whose square root + * is to be computed. * * The class stores a reference to \p A, so it should not be * changed (or destroyed) before compute() is called. */ - MatrixSquareRoot(const MatrixType& A); - + MatrixSquareRootQuasiTriangular(const MatrixType& A) + : m_A(A) + { + eigen_assert(A.rows() == A.cols()); + } + /** \brief Compute the matrix square root * * \param[out] result square root of \p A, as specified in the constructor. * - * See MatrixBase::sqrt() for details on how this computation - * is implemented. + * Only the upper Hessenberg part of \p result is updated, the + * rest is not touched. See MatrixBase::sqrt() for details on + * how this computation is implemented. */ - template <typename ResultType> - void compute(ResultType &result); -}; - - -// ********** Partial specialization for real matrices ********** - -template <typename MatrixType> -class MatrixSquareRoot<MatrixType, 0> -{ -public: - MatrixSquareRoot(const MatrixType& A) - : m_A(A) - { - eigen_assert(A.rows() == A.cols()); - } - - template <typename ResultType> void compute(ResultType &result); + template <typename ResultType> void compute(ResultType &result); + + private: + typedef typename MatrixType::Index Index; + typedef typename MatrixType::Scalar Scalar; + + void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T); + void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T); + void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i); + void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, + typename MatrixType::Index i, typename MatrixType::Index j); + void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, + typename MatrixType::Index i, typename MatrixType::Index j); + void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, + typename MatrixType::Index i, typename MatrixType::Index j); + void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, + typename MatrixType::Index i, typename MatrixType::Index j); -private: - typedef typename MatrixType::Index Index; - typedef typename MatrixType::Scalar Scalar; + template <typename SmallMatrixType> + static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A, + const SmallMatrixType& B, const SmallMatrixType& C); - void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T); - void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T); - void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i); - void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, - typename MatrixType::Index i, typename MatrixType::Index j); - void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, - typename MatrixType::Index i, typename MatrixType::Index j); - void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, - typename MatrixType::Index i, typename MatrixType::Index j); - void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, - typename MatrixType::Index i, typename MatrixType::Index j); - - template <typename SmallMatrixType> - static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A, - const SmallMatrixType& B, const SmallMatrixType& C); - - const MatrixType& m_A; + const MatrixType& m_A; }; template <typename MatrixType> template <typename ResultType> -void MatrixSquareRoot<MatrixType, 0>::compute(ResultType &result) +void MatrixSquareRootQuasiTriangular<MatrixType>::compute(ResultType &result) { // Compute Schur decomposition of m_A const RealSchur<MatrixType> schurOfA(m_A); @@ -114,7 +109,8 @@ void MatrixSquareRoot<MatrixType, 0>::compute(ResultType &result) // pre: T is quasi-upper-triangular and sqrtT is a zero matrix of the same size // post: the diagonal blocks of sqrtT are the square roots of the diagonal blocks of T template <typename MatrixType> -void MatrixSquareRoot<MatrixType, 0>::computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T) +void MatrixSquareRootQuasiTriangular<MatrixType>::computeDiagonalPartOfSqrt(MatrixType& sqrtT, + const MatrixType& T) { const Index size = m_A.rows(); for (Index i = 0; i < size; i++) { @@ -132,7 +128,8 @@ void MatrixSquareRoot<MatrixType, 0>::computeDiagonalPartOfSqrt(MatrixType& sqrt // pre: T is quasi-upper-triangular and diagonal blocks of sqrtT are square root of diagonal blocks of T. // post: sqrtT is the square root of T. template <typename MatrixType> -void MatrixSquareRoot<MatrixType, 0>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T) +void MatrixSquareRootQuasiTriangular<MatrixType>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, + const MatrixType& T) { const Index size = m_A.rows(); for (Index j = 1; j < size; j++) { @@ -158,9 +155,8 @@ void MatrixSquareRoot<MatrixType, 0>::computeOffDiagonalPartOfSqrt(MatrixType& s // pre: T.block(i,i,2,2) has complex conjugate eigenvalues // post: sqrtT.block(i,i,2,2) is square root of T.block(i,i,2,2) template <typename MatrixType> -void MatrixSquareRoot<MatrixType, 0>::compute2x2diagonalBlock(MatrixType& sqrtT, - const MatrixType& T, - typename MatrixType::Index i) +void MatrixSquareRootQuasiTriangular<MatrixType> + ::compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i) { // TODO: This case (2-by-2 blocks with complex conjugate eigenvalues) is probably hidden somewhere // in EigenSolver. If we expose it, we could call it directly from here. @@ -174,10 +170,9 @@ void MatrixSquareRoot<MatrixType, 0>::compute2x2diagonalBlock(MatrixType& sqrtT, // all blocks of sqrtT to left of and below (i,j) are correct // post: sqrtT(i,j) has the correct value template <typename MatrixType> -void MatrixSquareRoot<MatrixType, 0>::compute1x1offDiagonalBlock(MatrixType& sqrtT, - const MatrixType& T, - typename MatrixType::Index i, - typename MatrixType::Index j) +void MatrixSquareRootQuasiTriangular<MatrixType> + ::compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, + typename MatrixType::Index i, typename MatrixType::Index j) { Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value(); sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j)); @@ -185,10 +180,9 @@ void MatrixSquareRoot<MatrixType, 0>::compute1x1offDiagonalBlock(MatrixType& sqr // similar to compute1x1offDiagonalBlock() template <typename MatrixType> -void MatrixSquareRoot<MatrixType, 0>::compute1x2offDiagonalBlock(MatrixType& sqrtT, - const MatrixType& T, - typename MatrixType::Index i, - typename MatrixType::Index j) +void MatrixSquareRootQuasiTriangular<MatrixType> + ::compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, + typename MatrixType::Index i, typename MatrixType::Index j) { Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j); if (j-i > 1) @@ -200,10 +194,9 @@ void MatrixSquareRoot<MatrixType, 0>::compute1x2offDiagonalBlock(MatrixType& sqr // similar to compute1x1offDiagonalBlock() template <typename MatrixType> -void MatrixSquareRoot<MatrixType, 0>::compute2x1offDiagonalBlock(MatrixType& sqrtT, - const MatrixType& T, - typename MatrixType::Index i, - typename MatrixType::Index j) +void MatrixSquareRootQuasiTriangular<MatrixType> + ::compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, + typename MatrixType::Index i, typename MatrixType::Index j) { Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j); if (j-i > 2) @@ -215,10 +208,9 @@ void MatrixSquareRoot<MatrixType, 0>::compute2x1offDiagonalBlock(MatrixType& sqr // similar to compute1x1offDiagonalBlock() template <typename MatrixType> -void MatrixSquareRoot<MatrixType, 0>::compute2x2offDiagonalBlock(MatrixType& sqrtT, - const MatrixType& T, - typename MatrixType::Index i, - typename MatrixType::Index j) +void MatrixSquareRootQuasiTriangular<MatrixType> + ::compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, + typename MatrixType::Index i, typename MatrixType::Index j) { Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i); Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j); @@ -233,10 +225,9 @@ void MatrixSquareRoot<MatrixType, 0>::compute2x2offDiagonalBlock(MatrixType& sqr // solves the equation A X + X B = C where all matrices are 2-by-2 template <typename MatrixType> template <typename SmallMatrixType> -void MatrixSquareRoot<MatrixType, 0>::solveAuxiliaryEquation(SmallMatrixType& X, - const SmallMatrixType& A, - const SmallMatrixType& B, - const SmallMatrixType& C) +void MatrixSquareRootQuasiTriangular<MatrixType> + ::solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A, + const SmallMatrixType& B, const SmallMatrixType& C) { EIGEN_STATIC_ASSERT((internal::is_same<SmallMatrixType, Matrix<Scalar,2,2> >::value), EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT); @@ -270,18 +261,37 @@ void MatrixSquareRoot<MatrixType, 0>::solveAuxiliaryEquation(SmallMatrixType& X, X.coeffRef(1,1) = result.coeff(3); } -// ********** Partial specialization for complex matrices ********** +/** \ingroup MatrixFunctions_Module + * \brief Class for computing matrix square roots of upper triangular matrices. + * \tparam MatrixType type of the argument of the matrix square root, + * expected to be an instantiation of the Matrix class template. + * + * This class computes the square root of the upper triangular matrix + * stored in the upper triangular part (including the diagonal) of + * the matrix passed to the constructor. + * + * \sa MatrixSquareRoot, MatrixSquareRootQuasiTriangular + */ template <typename MatrixType> -class MatrixSquareRoot<MatrixType, 1> +class MatrixSquareRootTriangular { public: - MatrixSquareRoot(const MatrixType& A) + MatrixSquareRootTriangular(const MatrixType& A) : m_A(A) { eigen_assert(A.rows() == A.cols()); } + /** \brief Compute the matrix square root + * + * \param[out] result square root of \p A, as specified in the constructor. + * + * Only the upper triangular part (including the diagonal) of + * \p result is updated, the rest is not touched. See + * MatrixBase::sqrt() for details on how this computation is + * implemented. + */ template <typename ResultType> void compute(ResultType &result); private: @@ -290,7 +300,7 @@ class MatrixSquareRoot<MatrixType, 1> template <typename MatrixType> template <typename ResultType> -void MatrixSquareRoot<MatrixType, 1>::compute(ResultType &result) +void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result) { // Compute Schur decomposition of m_A const ComplexSchur<MatrixType> schurOfA(m_A); @@ -320,6 +330,107 @@ void MatrixSquareRoot<MatrixType, 1>::compute(ResultType &result) result.noalias() = tmp * U.adjoint(); } + +/** \ingroup MatrixFunctions_Module + * \brief Class for computing matrix square roots of general matrices. + * \tparam MatrixType type of the argument of the matrix square root, + * expected to be an instantiation of the Matrix class template. + * + * \sa MatrixSquareRootTriangular, MatrixSquareRootQuasiTriangular, MatrixBase::sqrt() + */ +template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex> +class MatrixSquareRoot +{ + public: + + /** \brief Constructor. + * + * \param[in] A matrix whose square root is to be computed. + * + * The class stores a reference to \p A, so it should not be + * changed (or destroyed) before compute() is called. + */ + MatrixSquareRoot(const MatrixType& A); + + /** \brief Compute the matrix square root + * + * \param[out] result square root of \p A, as specified in the constructor. + * + * See MatrixBase::sqrt() for details on how this computation is + * implemented. + */ + template <typename ResultType> void compute(ResultType &result); +}; + + +// ********** Partial specialization for real matrices ********** + +template <typename MatrixType> +class MatrixSquareRoot<MatrixType, 0> +{ + public: + + MatrixSquareRoot(const MatrixType& A) + : m_A(A) + { + eigen_assert(A.rows() == A.cols()); + } + + template <typename ResultType> void compute(ResultType &result) + { + // Compute Schur decomposition of m_A + const RealSchur<MatrixType> schurOfA(m_A); + const MatrixType& T = schurOfA.matrixT(); + const MatrixType& U = schurOfA.matrixU(); + + // Compute square root of T + MatrixSquareRootQuasiTriangular<MatrixType> tmp(T); + MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows()); + tmp.compute(sqrtT); + + // Compute square root of m_A + result = U * sqrtT * U.adjoint(); + } + + private: + const MatrixType& m_A; +}; + + +// ********** Partial specialization for complex matrices ********** + +template <typename MatrixType> +class MatrixSquareRoot<MatrixType, 1> +{ + public: + + MatrixSquareRoot(const MatrixType& A) + : m_A(A) + { + eigen_assert(A.rows() == A.cols()); + } + + template <typename ResultType> void compute(ResultType &result) + { + // Compute Schur decomposition of m_A + const ComplexSchur<MatrixType> schurOfA(m_A); + const MatrixType& T = schurOfA.matrixT(); + const MatrixType& U = schurOfA.matrixU(); + + // Compute square root of T + MatrixSquareRootTriangular<MatrixType> tmp(T); + MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows()); + tmp.compute(sqrtT); + + // Compute square root of m_A + result = U * sqrtT * U.adjoint(); + } + + private: + const MatrixType& m_A; +}; + + /** \ingroup MatrixFunctions_Module * * \brief Proxy for the matrix square root of some matrix (expression). |