diff options
author | Jitse Niesen <jitse@maths.leeds.ac.uk> | 2013-07-22 13:56:15 +0100 |
---|---|---|
committer | Jitse Niesen <jitse@maths.leeds.ac.uk> | 2013-07-22 13:56:15 +0100 |
commit | 084dc63b4ccfcc9a83a12973505af74a8bc32839 (patch) | |
tree | 75b3adbabc05897ff38f446e1c6177e273d8b767 | |
parent | 463343fb37cbd76e2527c560557a4d27a3d29c80 (diff) |
Clean-up of MatrixSquareRoot.
-rw-r--r-- | unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h | 2 | ||||
-rw-r--r-- | unsupported/Eigen/src/MatrixFunctions/MatrixPower.h | 2 | ||||
-rw-r--r-- | unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h | 413 | ||||
-rw-r--r-- | unsupported/test/matrix_power.cpp | 8 |
4 files changed, 163 insertions, 262 deletions
diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h b/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h index 33cfadfb4..4b1eb5a34 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h @@ -141,7 +141,7 @@ void MatrixLogarithmAtomic<MatrixType>::computeBig(const MatrixType& A, MatrixTy break; ++numberOfExtraSquareRoots; } - MatrixSquareRootTriangular<MatrixType>(T).compute(sqrtT); + matrix_sqrt_triangular(T, sqrtT); T = sqrtT.template triangularView<Upper>(); ++numberOfSquareRoots; } diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h b/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h index 5548bd95c..ee665c18e 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h @@ -219,7 +219,7 @@ void MatrixPowerAtomic<MatrixType>::computeBig(ResultType& res) const break; hasExtraSquareRoot = true; } - MatrixSquareRootTriangular<MatrixType>(T).compute(sqrtT); + matrix_sqrt_triangular(T, sqrtT); T = sqrtT.template triangularView<Upper>(); ++numberOfSquareRoots; } diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h b/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h index 0cd39ebe4..0261d4aa9 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h @@ -12,133 +12,16 @@ namespace Eigen { -/** \ingroup MatrixFunctions_Module - * \brief Class for computing matrix square roots of upper quasi-triangular matrices. - * \tparam MatrixType type of the argument of the matrix square root, - * expected to be an instantiation of the Matrix class template. - * - * This class computes the square root of the upper quasi-triangular - * matrix stored in the upper Hessenberg part of the matrix passed to - * the constructor. - * - * \sa MatrixSquareRoot, MatrixSquareRootTriangular - */ -template <typename MatrixType> -class MatrixSquareRootQuasiTriangular : internal::noncopyable -{ - public: - - /** \brief Constructor. - * - * \param[in] A upper quasi-triangular matrix whose square root - * is to be computed. - * - * The class stores a reference to \p A, so it should not be - * changed (or destroyed) before compute() is called. - */ - explicit MatrixSquareRootQuasiTriangular(const MatrixType& A) - : m_A(A) - { - eigen_assert(A.rows() == A.cols()); - } - - /** \brief Compute the matrix square root - * - * \param[out] result square root of \p A, as specified in the constructor. - * - * Only the upper Hessenberg part of \p result is updated, the - * rest is not touched. See MatrixBase::sqrt() for details on - * how this computation is implemented. - */ - template <typename ResultType> void compute(ResultType &result); - - private: - typedef typename MatrixType::Index Index; - typedef typename MatrixType::Scalar Scalar; - - void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T); - void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T); - void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i); - void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, - typename MatrixType::Index i, typename MatrixType::Index j); - void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, - typename MatrixType::Index i, typename MatrixType::Index j); - void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, - typename MatrixType::Index i, typename MatrixType::Index j); - void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, - typename MatrixType::Index i, typename MatrixType::Index j); - - template <typename SmallMatrixType> - static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A, - const SmallMatrixType& B, const SmallMatrixType& C); - - const MatrixType& m_A; -}; - -template <typename MatrixType> -template <typename ResultType> -void MatrixSquareRootQuasiTriangular<MatrixType>::compute(ResultType &result) -{ - result.resize(m_A.rows(), m_A.cols()); - computeDiagonalPartOfSqrt(result, m_A); - computeOffDiagonalPartOfSqrt(result, m_A); -} - -// pre: T is quasi-upper-triangular and sqrtT is a zero matrix of the same size -// post: the diagonal blocks of sqrtT are the square roots of the diagonal blocks of T -template <typename MatrixType> -void MatrixSquareRootQuasiTriangular<MatrixType>::computeDiagonalPartOfSqrt(MatrixType& sqrtT, - const MatrixType& T) -{ - using std::sqrt; - const Index size = m_A.rows(); - for (Index i = 0; i < size; i++) { - if (i == size - 1 || T.coeff(i+1, i) == 0) { - eigen_assert(T(i,i) >= 0); - sqrtT.coeffRef(i,i) = sqrt(T.coeff(i,i)); - } - else { - compute2x2diagonalBlock(sqrtT, T, i); - ++i; - } - } -} - -// pre: T is quasi-upper-triangular and diagonal blocks of sqrtT are square root of diagonal blocks of T. -// post: sqrtT is the square root of T. -template <typename MatrixType> -void MatrixSquareRootQuasiTriangular<MatrixType>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, - const MatrixType& T) -{ - const Index size = m_A.rows(); - for (Index j = 1; j < size; j++) { - if (T.coeff(j, j-1) != 0) // if T(j-1:j, j-1:j) is a 2-by-2 block - continue; - for (Index i = j-1; i >= 0; i--) { - if (i > 0 && T.coeff(i, i-1) != 0) // if T(i-1:i, i-1:i) is a 2-by-2 block - continue; - bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0); - bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0); - if (iBlockIs2x2 && jBlockIs2x2) - compute2x2offDiagonalBlock(sqrtT, T, i, j); - else if (iBlockIs2x2 && !jBlockIs2x2) - compute2x1offDiagonalBlock(sqrtT, T, i, j); - else if (!iBlockIs2x2 && jBlockIs2x2) - compute1x2offDiagonalBlock(sqrtT, T, i, j); - else if (!iBlockIs2x2 && !jBlockIs2x2) - compute1x1offDiagonalBlock(sqrtT, T, i, j); - } - } -} +namespace internal { // pre: T.block(i,i,2,2) has complex conjugate eigenvalues // post: sqrtT.block(i,i,2,2) is square root of T.block(i,i,2,2) -template <typename MatrixType> -void MatrixSquareRootQuasiTriangular<MatrixType> - ::compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i) +template <typename MatrixType, typename ResultType> +void matrix_sqrt_quasi_triangular_2x2_diagonal_block(const MatrixType& T, typename MatrixType::Index i, ResultType& sqrtT) { // TODO: This case (2-by-2 blocks with complex conjugate eigenvalues) is probably hidden somewhere // in EigenSolver. If we expose it, we could call it directly from here. + typedef typename traits<MatrixType>::Scalar Scalar; Matrix<Scalar,2,2> block = T.template block<2,2>(i,i); EigenSolver<Matrix<Scalar,2,2> > es(block); sqrtT.template block<2,2>(i,i) @@ -148,21 +31,19 @@ void MatrixSquareRootQuasiTriangular<MatrixType> // pre: block structure of T is such that (i,j) is a 1x1 block, // all blocks of sqrtT to left of and below (i,j) are correct // post: sqrtT(i,j) has the correct value -template <typename MatrixType> -void MatrixSquareRootQuasiTriangular<MatrixType> - ::compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, - typename MatrixType::Index i, typename MatrixType::Index j) +template <typename MatrixType, typename ResultType> +void matrix_sqrt_quasi_triangular_1x1_off_diagonal_block(const MatrixType& T, typename MatrixType::Index i, typename MatrixType::Index j, ResultType& sqrtT) { + typedef typename traits<MatrixType>::Scalar Scalar; Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value(); sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j)); } // similar to compute1x1offDiagonalBlock() -template <typename MatrixType> -void MatrixSquareRootQuasiTriangular<MatrixType> - ::compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, - typename MatrixType::Index i, typename MatrixType::Index j) +template <typename MatrixType, typename ResultType> +void matrix_sqrt_quasi_triangular_1x2_off_diagonal_block(const MatrixType& T, typename MatrixType::Index i, typename MatrixType::Index j, ResultType& sqrtT) { + typedef typename traits<MatrixType>::Scalar Scalar; Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j); if (j-i > 1) rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2); @@ -172,11 +53,10 @@ void MatrixSquareRootQuasiTriangular<MatrixType> } // similar to compute1x1offDiagonalBlock() -template <typename MatrixType> -void MatrixSquareRootQuasiTriangular<MatrixType> - ::compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, - typename MatrixType::Index i, typename MatrixType::Index j) +template <typename MatrixType, typename ResultType> +void matrix_sqrt_quasi_triangular_2x1_off_diagonal_block(const MatrixType& T, typename MatrixType::Index i, typename MatrixType::Index j, ResultType& sqrtT) { + typedef typename traits<MatrixType>::Scalar Scalar; Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j); if (j-i > 2) rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1); @@ -186,31 +66,25 @@ void MatrixSquareRootQuasiTriangular<MatrixType> } // similar to compute1x1offDiagonalBlock() -template <typename MatrixType> -void MatrixSquareRootQuasiTriangular<MatrixType> - ::compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, - typename MatrixType::Index i, typename MatrixType::Index j) +template <typename MatrixType, typename ResultType> +void matrix_sqrt_quasi_triangular_2x2_off_diagonal_block(const MatrixType& T, typename MatrixType::Index i, typename MatrixType::Index j, ResultType& sqrtT) { + typedef typename traits<MatrixType>::Scalar Scalar; Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i); Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j); Matrix<Scalar,2,2> C = T.template block<2,2>(i,j); if (j-i > 2) C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2); Matrix<Scalar,2,2> X; - solveAuxiliaryEquation(X, A, B, C); + matrix_sqrt_quasi_triangular_solve_auxiliary_equation(X, A, B, C); sqrtT.template block<2,2>(i,j) = X; } // solves the equation A X + X B = C where all matrices are 2-by-2 template <typename MatrixType> -template <typename SmallMatrixType> -void MatrixSquareRootQuasiTriangular<MatrixType> - ::solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A, - const SmallMatrixType& B, const SmallMatrixType& C) +void matrix_sqrt_quasi_triangular_solve_auxiliary_equation(MatrixType& X, const MatrixType& A, const MatrixType& B, const MatrixType& C) { - EIGEN_STATIC_ASSERT((internal::is_same<SmallMatrixType, Matrix<Scalar,2,2> >::value), - EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT); - + typedef typename traits<MatrixType>::Scalar Scalar; Matrix<Scalar,4,4> coeffMatrix = Matrix<Scalar,4,4>::Zero(); coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0); coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1); @@ -241,164 +115,193 @@ void MatrixSquareRootQuasiTriangular<MatrixType> } +// pre: T is quasi-upper-triangular and sqrtT is a zero matrix of the same size +// post: the diagonal blocks of sqrtT are the square roots of the diagonal blocks of T +template <typename MatrixType, typename ResultType> +void matrix_sqrt_quasi_triangular_diagonal(const MatrixType& T, ResultType& sqrtT) +{ + using std::sqrt; + typedef typename MatrixType::Index Index; + const Index size = T.rows(); + for (Index i = 0; i < size; i++) { + if (i == size - 1 || T.coeff(i+1, i) == 0) { + eigen_assert(T(i,i) >= 0); + sqrtT.coeffRef(i,i) = sqrt(T.coeff(i,i)); + } + else { + matrix_sqrt_quasi_triangular_2x2_diagonal_block(T, i, sqrtT); + ++i; + } + } +} + +// pre: T is quasi-upper-triangular and diagonal blocks of sqrtT are square root of diagonal blocks of T. +// post: sqrtT is the square root of T. +template <typename MatrixType, typename ResultType> +void matrix_sqrt_quasi_triangular_off_diagonal(const MatrixType& T, ResultType& sqrtT) +{ + typedef typename MatrixType::Index Index; + const Index size = T.rows(); + for (Index j = 1; j < size; j++) { + if (T.coeff(j, j-1) != 0) // if T(j-1:j, j-1:j) is a 2-by-2 block + continue; + for (Index i = j-1; i >= 0; i--) { + if (i > 0 && T.coeff(i, i-1) != 0) // if T(i-1:i, i-1:i) is a 2-by-2 block + continue; + bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0); + bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0); + if (iBlockIs2x2 && jBlockIs2x2) + matrix_sqrt_quasi_triangular_2x2_off_diagonal_block(T, i, j, sqrtT); + else if (iBlockIs2x2 && !jBlockIs2x2) + matrix_sqrt_quasi_triangular_2x1_off_diagonal_block(T, i, j, sqrtT); + else if (!iBlockIs2x2 && jBlockIs2x2) + matrix_sqrt_quasi_triangular_1x2_off_diagonal_block(T, i, j, sqrtT); + else if (!iBlockIs2x2 && !jBlockIs2x2) + matrix_sqrt_quasi_triangular_1x1_off_diagonal_block(T, i, j, sqrtT); + } + } +} + +} // end of namespace internal + /** \ingroup MatrixFunctions_Module - * \brief Class for computing matrix square roots of upper triangular matrices. - * \tparam MatrixType type of the argument of the matrix square root, + * \brief Compute matrix square root of quasi-triangular matrix. + * + * \tparam MatrixType type of \p arg, the argument of matrix square root, * expected to be an instantiation of the Matrix class template. + * \tparam ResultType type of \p result, where result is to be stored. + * \param[in] arg argument of matrix square root. + * \param[out] result matrix square root of upper Hessenberg part of \p arg. * - * This class computes the square root of the upper triangular matrix - * stored in the upper triangular part (including the diagonal) of - * the matrix passed to the constructor. + * This function computes the square root of the upper quasi-triangular matrix stored in the upper + * Hessenberg part of \p arg. Only the upper Hessenberg part of \p result is updated, the rest is + * not touched. See MatrixBase::sqrt() for details on how this computation is implemented. * * \sa MatrixSquareRoot, MatrixSquareRootQuasiTriangular */ -template <typename MatrixType> -class MatrixSquareRootTriangular : internal::noncopyable +template <typename MatrixType, typename ResultType> +void matrix_sqrt_quasi_triangular(const MatrixType &arg, ResultType &result) { - public: - explicit MatrixSquareRootTriangular(const MatrixType& A) - : m_A(A) - { - eigen_assert(A.rows() == A.cols()); - } - - /** \brief Compute the matrix square root - * - * \param[out] result square root of \p A, as specified in the constructor. - * - * Only the upper triangular part (including the diagonal) of - * \p result is updated, the rest is not touched. See - * MatrixBase::sqrt() for details on how this computation is - * implemented. - */ - template <typename ResultType> void compute(ResultType &result); + eigen_assert(arg.rows() == arg.cols()); + result.resize(arg.rows(), arg.cols()); + internal::matrix_sqrt_quasi_triangular_diagonal(arg, result); + internal::matrix_sqrt_quasi_triangular_off_diagonal(arg, result); +} - private: - const MatrixType& m_A; -}; -template <typename MatrixType> -template <typename ResultType> -void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result) +/** \ingroup MatrixFunctions_Module + * \brief Compute matrix square root of triangular matrix. + * + * \tparam MatrixType type of \p arg, the argument of matrix square root, + * expected to be an instantiation of the Matrix class template. + * \tparam ResultType type of \p result, where result is to be stored. + * \param[in] arg argument of matrix square root. + * \param[out] result matrix square root of upper triangular part of \p arg. + * + * Only the upper triangular part (including the diagonal) of \p result is updated, the rest is not + * touched. See MatrixBase::sqrt() for details on how this computation is implemented. + * + * \sa MatrixSquareRoot, MatrixSquareRootQuasiTriangular + */ +template <typename MatrixType, typename ResultType> +void matrix_sqrt_triangular(const MatrixType &arg, ResultType &result) { using std::sqrt; + typedef typename MatrixType::Index Index; + typedef typename MatrixType::Scalar Scalar; - // Compute square root of m_A and store it in upper triangular part of result + eigen_assert(arg.rows() == arg.cols()); + + // Compute square root of arg and store it in upper triangular part of result // This uses that the square root of triangular matrices can be computed directly. - result.resize(m_A.rows(), m_A.cols()); - typedef typename MatrixType::Index Index; - for (Index i = 0; i < m_A.rows(); i++) { - result.coeffRef(i,i) = sqrt(m_A.coeff(i,i)); + result.resize(arg.rows(), arg.cols()); + for (Index i = 0; i < arg.rows(); i++) { + result.coeffRef(i,i) = sqrt(arg.coeff(i,i)); } - for (Index j = 1; j < m_A.cols(); j++) { + for (Index j = 1; j < arg.cols(); j++) { for (Index i = j-1; i >= 0; i--) { - typedef typename MatrixType::Scalar Scalar; // if i = j-1, then segment has length 0 so tmp = 0 Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value(); // denominator may be zero if original matrix is singular - result.coeffRef(i,j) = (m_A.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j)); + result.coeffRef(i,j) = (arg.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j)); } } } +namespace internal { + /** \ingroup MatrixFunctions_Module - * \brief Class for computing matrix square roots of general matrices. + * \brief Helper struct for computing matrix square roots of general matrices. * \tparam MatrixType type of the argument of the matrix square root, * expected to be an instantiation of the Matrix class template. * * \sa MatrixSquareRootTriangular, MatrixSquareRootQuasiTriangular, MatrixBase::sqrt() */ template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex> -class MatrixSquareRoot +struct matrix_sqrt_compute { - public: - - /** \brief Constructor. - * - * \param[in] A matrix whose square root is to be computed. - * - * The class stores a reference to \p A, so it should not be - * changed (or destroyed) before compute() is called. - */ - explicit MatrixSquareRoot(const MatrixType& A); - - /** \brief Compute the matrix square root - * - * \param[out] result square root of \p A, as specified in the constructor. - * - * See MatrixBase::sqrt() for details on how this computation is - * implemented. - */ - template <typename ResultType> void compute(ResultType &result); + /** \brief Compute the matrix square root + * + * \param[in] arg matrix whose square root is to be computed. + * \param[out] result square root of \p arg. + * + * See MatrixBase::sqrt() for details on how this computation is implemented. + */ + template <typename ResultType> static void run(const MatrixType &arg, ResultType &result); }; // ********** Partial specialization for real matrices ********** template <typename MatrixType> -class MatrixSquareRoot<MatrixType, 0> +struct matrix_sqrt_compute<MatrixType, 0> { - public: - - explicit MatrixSquareRoot(const MatrixType& A) - : m_A(A) - { - eigen_assert(A.rows() == A.cols()); - } - - template <typename ResultType> void compute(ResultType &result) - { - // Compute Schur decomposition of m_A - const RealSchur<MatrixType> schurOfA(m_A); - const MatrixType& T = schurOfA.matrixT(); - const MatrixType& U = schurOfA.matrixU(); - - // Compute square root of T - MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.cols()); - MatrixSquareRootQuasiTriangular<MatrixType>(T).compute(sqrtT); + template <typename ResultType> + static void run(const MatrixType &arg, ResultType &result) + { + eigen_assert(arg.rows() == arg.cols()); + + // Compute Schur decomposition of arg + const RealSchur<MatrixType> schurOfA(arg); + const MatrixType& T = schurOfA.matrixT(); + const MatrixType& U = schurOfA.matrixU(); - // Compute square root of m_A - result = U * sqrtT * U.adjoint(); - } + // Compute square root of T + MatrixType sqrtT = MatrixType::Zero(arg.rows(), arg.cols()); + matrix_sqrt_quasi_triangular(T, sqrtT); - private: - const MatrixType& m_A; + // Compute square root of arg + result = U * sqrtT * U.adjoint(); + } }; // ********** Partial specialization for complex matrices ********** template <typename MatrixType> -class MatrixSquareRoot<MatrixType, 1> : internal::noncopyable +struct matrix_sqrt_compute<MatrixType, 1> { - public: - - explicit MatrixSquareRoot(const MatrixType& A) - : m_A(A) - { - eigen_assert(A.rows() == A.cols()); - } - - template <typename ResultType> void compute(ResultType &result) - { - // Compute Schur decomposition of m_A - const ComplexSchur<MatrixType> schurOfA(m_A); - const MatrixType& T = schurOfA.matrixT(); - const MatrixType& U = schurOfA.matrixU(); + template <typename ResultType> + static void run(const MatrixType &arg, ResultType &result) + { + eigen_assert(arg.rows() == arg.cols()); + + // Compute Schur decomposition of arg + const ComplexSchur<MatrixType> schurOfA(arg); + const MatrixType& T = schurOfA.matrixT(); + const MatrixType& U = schurOfA.matrixU(); - // Compute square root of T - MatrixType sqrtT; - MatrixSquareRootTriangular<MatrixType>(T).compute(sqrtT); + // Compute square root of T + MatrixType sqrtT; + matrix_sqrt_triangular(T, sqrtT); - // Compute square root of m_A - result = U * (sqrtT.template triangularView<Upper>() * U.adjoint()); - } - - private: - const MatrixType& m_A; + // Compute square root of arg + result = U * (sqrtT.template triangularView<Upper>() * U.adjoint()); + } }; +} // end namespace internal /** \ingroup MatrixFunctions_Module * @@ -432,9 +335,9 @@ template<typename Derived> class MatrixSquareRootReturnValue template <typename ResultType> inline void evalTo(ResultType& result) const { - const typename Derived::PlainObject srcEvaluated = m_src.eval(); - MatrixSquareRoot<typename Derived::PlainObject> me(srcEvaluated); - me.compute(result); + typedef typename Derived::PlainObject PlainObject; + const PlainObject srcEvaluated = m_src.eval(); + internal::matrix_sqrt_compute<PlainObject>::run(srcEvaluated, result); } Index rows() const { return m_src.rows(); } diff --git a/unsupported/test/matrix_power.cpp b/unsupported/test/matrix_power.cpp index 849e4287b..4c4cac509 100644 --- a/unsupported/test/matrix_power.cpp +++ b/unsupported/test/matrix_power.cpp @@ -100,8 +100,6 @@ template<typename MatrixType> void testSingular(MatrixType m, double tol) { const int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex; - typedef typename internal::conditional< IsComplex, MatrixSquareRootTriangular<MatrixType>, - MatrixSquareRootQuasiTriangular<MatrixType> >::type SquareRootType; typedef typename internal::conditional<IsComplex, TriangularView<MatrixType,Upper>, const MatrixType&>::type TriangularType; typename internal::conditional< IsComplex, ComplexSchur<MatrixType>, RealSchur<MatrixType> >::type schur; MatrixType T; @@ -116,13 +114,13 @@ void testSingular(MatrixType m, double tol) processTriangularMatrix<MatrixType>::run(m, T, U); MatrixPower<MatrixType> mpow(m); - SquareRootType(T).compute(T); + T = T.sqrt(); VERIFY(mpow(0.5).isApprox(U * (TriangularType(T) * U.adjoint()), tol)); - SquareRootType(T).compute(T); + T = T.sqrt(); VERIFY(mpow(0.25).isApprox(U * (TriangularType(T) * U.adjoint()), tol)); - SquareRootType(T).compute(T); + T = T.sqrt(); VERIFY(mpow(0.125).isApprox(U * (TriangularType(T) * U.adjoint()), tol)); } } |