aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported
diff options
context:
space:
mode:
authorGravatar Jitse Niesen <jitse@maths.leeds.ac.uk>2011-08-25 07:42:21 +0100
committerGravatar Jitse Niesen <jitse@maths.leeds.ac.uk>2011-08-25 07:42:21 +0100
commitc01ed935dd6fd536134c26c2eda7511269c28a6f (patch)
tree45216d5407ea0dd062f7e5fc9d68296fe8a92832 /unsupported
parent8ddd1e390b43b9d86c02fbdacdf7b70322eecfff (diff)
Split code for (quasi)triangular matrices from MatrixSquareRoot.
This way, (quasi)triangular matrices can avoid the costly Schur decomposition.
Diffstat (limited to 'unsupported')
-rw-r--r--unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h269
1 files changed, 190 insertions, 79 deletions
diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h b/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h
index b56fcf06b..ed1b5ee35 100644
--- a/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h
+++ b/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h
@@ -26,76 +26,71 @@
#define EIGEN_MATRIX_SQUARE_ROOT
/** \ingroup MatrixFunctions_Module
- * \brief Class for computing matrix square roots.
- * \tparam MatrixType type of the argument of the matrix square root,
- * expected to be an instantiation of the Matrix class template.
+ * \brief Class for computing matrix square roots of upper quasi-triangular matrices.
+ * \tparam MatrixType type of the argument of the matrix square root,
+ * expected to be an instantiation of the Matrix class template.
+ *
+ * This class computes the square root of the upper quasi-triangular
+ * matrix stored in the upper Hessenberg part of the matrix passed to
+ * the constructor.
+ *
+ * \sa MatrixSquareRoot, MatrixSquareRootTriangular
*/
-template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex>
-class MatrixSquareRoot
-{
+template <typename MatrixType>
+class MatrixSquareRootQuasiTriangular
+{
public:
/** \brief Constructor.
*
- * \param[in] A matrix whose square root is to be computed.
+ * \param[in] A upper quasi-triangular matrix whose square root
+ * is to be computed.
*
* The class stores a reference to \p A, so it should not be
* changed (or destroyed) before compute() is called.
*/
- MatrixSquareRoot(const MatrixType& A);
-
+ MatrixSquareRootQuasiTriangular(const MatrixType& A)
+ : m_A(A)
+ {
+ eigen_assert(A.rows() == A.cols());
+ }
+
/** \brief Compute the matrix square root
*
* \param[out] result square root of \p A, as specified in the constructor.
*
- * See MatrixBase::sqrt() for details on how this computation
- * is implemented.
+ * Only the upper Hessenberg part of \p result is updated, the
+ * rest is not touched. See MatrixBase::sqrt() for details on
+ * how this computation is implemented.
*/
- template <typename ResultType>
- void compute(ResultType &result);
-};
-
-
-// ********** Partial specialization for real matrices **********
-
-template <typename MatrixType>
-class MatrixSquareRoot<MatrixType, 0>
-{
-public:
- MatrixSquareRoot(const MatrixType& A)
- : m_A(A)
- {
- eigen_assert(A.rows() == A.cols());
- }
-
- template <typename ResultType> void compute(ResultType &result);
+ template <typename ResultType> void compute(ResultType &result);
+
+ private:
+ typedef typename MatrixType::Index Index;
+ typedef typename MatrixType::Scalar Scalar;
+
+ void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
+ void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
+ void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i);
+ void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
+ typename MatrixType::Index i, typename MatrixType::Index j);
+ void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
+ typename MatrixType::Index i, typename MatrixType::Index j);
+ void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
+ typename MatrixType::Index i, typename MatrixType::Index j);
+ void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
+ typename MatrixType::Index i, typename MatrixType::Index j);
-private:
- typedef typename MatrixType::Index Index;
- typedef typename MatrixType::Scalar Scalar;
+ template <typename SmallMatrixType>
+ static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
+ const SmallMatrixType& B, const SmallMatrixType& C);
- void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
- void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
- void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i);
- void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
- typename MatrixType::Index i, typename MatrixType::Index j);
- void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
- typename MatrixType::Index i, typename MatrixType::Index j);
- void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
- typename MatrixType::Index i, typename MatrixType::Index j);
- void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
- typename MatrixType::Index i, typename MatrixType::Index j);
-
- template <typename SmallMatrixType>
- static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
- const SmallMatrixType& B, const SmallMatrixType& C);
-
- const MatrixType& m_A;
+ const MatrixType& m_A;
};
template <typename MatrixType>
template <typename ResultType>
-void MatrixSquareRoot<MatrixType, 0>::compute(ResultType &result)
+void MatrixSquareRootQuasiTriangular<MatrixType>::compute(ResultType &result)
{
// Compute Schur decomposition of m_A
const RealSchur<MatrixType> schurOfA(m_A);
@@ -114,7 +109,8 @@ void MatrixSquareRoot<MatrixType, 0>::compute(ResultType &result)
// pre: T is quasi-upper-triangular and sqrtT is a zero matrix of the same size
// post: the diagonal blocks of sqrtT are the square roots of the diagonal blocks of T
template <typename MatrixType>
-void MatrixSquareRoot<MatrixType, 0>::computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T)
+void MatrixSquareRootQuasiTriangular<MatrixType>::computeDiagonalPartOfSqrt(MatrixType& sqrtT,
+ const MatrixType& T)
{
const Index size = m_A.rows();
for (Index i = 0; i < size; i++) {
@@ -132,7 +128,8 @@ void MatrixSquareRoot<MatrixType, 0>::computeDiagonalPartOfSqrt(MatrixType& sqrt
// pre: T is quasi-upper-triangular and diagonal blocks of sqrtT are square root of diagonal blocks of T.
// post: sqrtT is the square root of T.
template <typename MatrixType>
-void MatrixSquareRoot<MatrixType, 0>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T)
+void MatrixSquareRootQuasiTriangular<MatrixType>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT,
+ const MatrixType& T)
{
const Index size = m_A.rows();
for (Index j = 1; j < size; j++) {
@@ -158,9 +155,8 @@ void MatrixSquareRoot<MatrixType, 0>::computeOffDiagonalPartOfSqrt(MatrixType& s
// pre: T.block(i,i,2,2) has complex conjugate eigenvalues
// post: sqrtT.block(i,i,2,2) is square root of T.block(i,i,2,2)
template <typename MatrixType>
-void MatrixSquareRoot<MatrixType, 0>::compute2x2diagonalBlock(MatrixType& sqrtT,
- const MatrixType& T,
- typename MatrixType::Index i)
+void MatrixSquareRootQuasiTriangular<MatrixType>
+ ::compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i)
{
// TODO: This case (2-by-2 blocks with complex conjugate eigenvalues) is probably hidden somewhere
// in EigenSolver. If we expose it, we could call it directly from here.
@@ -174,10 +170,9 @@ void MatrixSquareRoot<MatrixType, 0>::compute2x2diagonalBlock(MatrixType& sqrtT,
// all blocks of sqrtT to left of and below (i,j) are correct
// post: sqrtT(i,j) has the correct value
template <typename MatrixType>
-void MatrixSquareRoot<MatrixType, 0>::compute1x1offDiagonalBlock(MatrixType& sqrtT,
- const MatrixType& T,
- typename MatrixType::Index i,
- typename MatrixType::Index j)
+void MatrixSquareRootQuasiTriangular<MatrixType>
+ ::compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
+ typename MatrixType::Index i, typename MatrixType::Index j)
{
Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value();
sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j));
@@ -185,10 +180,9 @@ void MatrixSquareRoot<MatrixType, 0>::compute1x1offDiagonalBlock(MatrixType& sqr
// similar to compute1x1offDiagonalBlock()
template <typename MatrixType>
-void MatrixSquareRoot<MatrixType, 0>::compute1x2offDiagonalBlock(MatrixType& sqrtT,
- const MatrixType& T,
- typename MatrixType::Index i,
- typename MatrixType::Index j)
+void MatrixSquareRootQuasiTriangular<MatrixType>
+ ::compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
+ typename MatrixType::Index i, typename MatrixType::Index j)
{
Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j);
if (j-i > 1)
@@ -200,10 +194,9 @@ void MatrixSquareRoot<MatrixType, 0>::compute1x2offDiagonalBlock(MatrixType& sqr
// similar to compute1x1offDiagonalBlock()
template <typename MatrixType>
-void MatrixSquareRoot<MatrixType, 0>::compute2x1offDiagonalBlock(MatrixType& sqrtT,
- const MatrixType& T,
- typename MatrixType::Index i,
- typename MatrixType::Index j)
+void MatrixSquareRootQuasiTriangular<MatrixType>
+ ::compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
+ typename MatrixType::Index i, typename MatrixType::Index j)
{
Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j);
if (j-i > 2)
@@ -215,10 +208,9 @@ void MatrixSquareRoot<MatrixType, 0>::compute2x1offDiagonalBlock(MatrixType& sqr
// similar to compute1x1offDiagonalBlock()
template <typename MatrixType>
-void MatrixSquareRoot<MatrixType, 0>::compute2x2offDiagonalBlock(MatrixType& sqrtT,
- const MatrixType& T,
- typename MatrixType::Index i,
- typename MatrixType::Index j)
+void MatrixSquareRootQuasiTriangular<MatrixType>
+ ::compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
+ typename MatrixType::Index i, typename MatrixType::Index j)
{
Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i);
Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j);
@@ -233,10 +225,9 @@ void MatrixSquareRoot<MatrixType, 0>::compute2x2offDiagonalBlock(MatrixType& sqr
// solves the equation A X + X B = C where all matrices are 2-by-2
template <typename MatrixType>
template <typename SmallMatrixType>
-void MatrixSquareRoot<MatrixType, 0>::solveAuxiliaryEquation(SmallMatrixType& X,
- const SmallMatrixType& A,
- const SmallMatrixType& B,
- const SmallMatrixType& C)
+void MatrixSquareRootQuasiTriangular<MatrixType>
+ ::solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
+ const SmallMatrixType& B, const SmallMatrixType& C)
{
EIGEN_STATIC_ASSERT((internal::is_same<SmallMatrixType, Matrix<Scalar,2,2> >::value),
EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
@@ -270,18 +261,37 @@ void MatrixSquareRoot<MatrixType, 0>::solveAuxiliaryEquation(SmallMatrixType& X,
X.coeffRef(1,1) = result.coeff(3);
}
-// ********** Partial specialization for complex matrices **********
+/** \ingroup MatrixFunctions_Module
+ * \brief Class for computing matrix square roots of upper triangular matrices.
+ * \tparam MatrixType type of the argument of the matrix square root,
+ * expected to be an instantiation of the Matrix class template.
+ *
+ * This class computes the square root of the upper triangular matrix
+ * stored in the upper triangular part (including the diagonal) of
+ * the matrix passed to the constructor.
+ *
+ * \sa MatrixSquareRoot, MatrixSquareRootQuasiTriangular
+ */
template <typename MatrixType>
-class MatrixSquareRoot<MatrixType, 1>
+class MatrixSquareRootTriangular
{
public:
- MatrixSquareRoot(const MatrixType& A)
+ MatrixSquareRootTriangular(const MatrixType& A)
: m_A(A)
{
eigen_assert(A.rows() == A.cols());
}
+ /** \brief Compute the matrix square root
+ *
+ * \param[out] result square root of \p A, as specified in the constructor.
+ *
+ * Only the upper triangular part (including the diagonal) of
+ * \p result is updated, the rest is not touched. See
+ * MatrixBase::sqrt() for details on how this computation is
+ * implemented.
+ */
template <typename ResultType> void compute(ResultType &result);
private:
@@ -290,7 +300,7 @@ class MatrixSquareRoot<MatrixType, 1>
template <typename MatrixType>
template <typename ResultType>
-void MatrixSquareRoot<MatrixType, 1>::compute(ResultType &result)
+void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result)
{
// Compute Schur decomposition of m_A
const ComplexSchur<MatrixType> schurOfA(m_A);
@@ -320,6 +330,107 @@ void MatrixSquareRoot<MatrixType, 1>::compute(ResultType &result)
result.noalias() = tmp * U.adjoint();
}
+
+/** \ingroup MatrixFunctions_Module
+ * \brief Class for computing matrix square roots of general matrices.
+ * \tparam MatrixType type of the argument of the matrix square root,
+ * expected to be an instantiation of the Matrix class template.
+ *
+ * \sa MatrixSquareRootTriangular, MatrixSquareRootQuasiTriangular, MatrixBase::sqrt()
+ */
+template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex>
+class MatrixSquareRoot
+{
+ public:
+
+ /** \brief Constructor.
+ *
+ * \param[in] A matrix whose square root is to be computed.
+ *
+ * The class stores a reference to \p A, so it should not be
+ * changed (or destroyed) before compute() is called.
+ */
+ MatrixSquareRoot(const MatrixType& A);
+
+ /** \brief Compute the matrix square root
+ *
+ * \param[out] result square root of \p A, as specified in the constructor.
+ *
+ * See MatrixBase::sqrt() for details on how this computation is
+ * implemented.
+ */
+ template <typename ResultType> void compute(ResultType &result);
+};
+
+
+// ********** Partial specialization for real matrices **********
+
+template <typename MatrixType>
+class MatrixSquareRoot<MatrixType, 0>
+{
+ public:
+
+ MatrixSquareRoot(const MatrixType& A)
+ : m_A(A)
+ {
+ eigen_assert(A.rows() == A.cols());
+ }
+
+ template <typename ResultType> void compute(ResultType &result)
+ {
+ // Compute Schur decomposition of m_A
+ const RealSchur<MatrixType> schurOfA(m_A);
+ const MatrixType& T = schurOfA.matrixT();
+ const MatrixType& U = schurOfA.matrixU();
+
+ // Compute square root of T
+ MatrixSquareRootQuasiTriangular<MatrixType> tmp(T);
+ MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
+ tmp.compute(sqrtT);
+
+ // Compute square root of m_A
+ result = U * sqrtT * U.adjoint();
+ }
+
+ private:
+ const MatrixType& m_A;
+};
+
+
+// ********** Partial specialization for complex matrices **********
+
+template <typename MatrixType>
+class MatrixSquareRoot<MatrixType, 1>
+{
+ public:
+
+ MatrixSquareRoot(const MatrixType& A)
+ : m_A(A)
+ {
+ eigen_assert(A.rows() == A.cols());
+ }
+
+ template <typename ResultType> void compute(ResultType &result)
+ {
+ // Compute Schur decomposition of m_A
+ const ComplexSchur<MatrixType> schurOfA(m_A);
+ const MatrixType& T = schurOfA.matrixT();
+ const MatrixType& U = schurOfA.matrixU();
+
+ // Compute square root of T
+ MatrixSquareRootTriangular<MatrixType> tmp(T);
+ MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
+ tmp.compute(sqrtT);
+
+ // Compute square root of m_A
+ result = U * sqrtT * U.adjoint();
+ }
+
+ private:
+ const MatrixType& m_A;
+};
+
+
/** \ingroup MatrixFunctions_Module
*
* \brief Proxy for the matrix square root of some matrix (expression).