From 6e1573f66a60511a00f6c7a54cf0918d23007c3e Mon Sep 17 00:00:00 2001 From: Jitse Niesen Date: Sun, 8 May 2011 22:18:37 +0100 Subject: Implement square root for real matrices via Schur. --- unsupported/test/matrix_square_root.cpp | 43 ++++++++++++++++++++++++++++----- 1 file changed, 37 insertions(+), 6 deletions(-) (limited to 'unsupported/test/matrix_square_root.cpp') diff --git a/unsupported/test/matrix_square_root.cpp b/unsupported/test/matrix_square_root.cpp index cd2c6cfc4..56b86a288 100644 --- a/unsupported/test/matrix_square_root.cpp +++ b/unsupported/test/matrix_square_root.cpp @@ -25,16 +25,45 @@ #include "main.h" #include +template ::Scalar>::IsComplex> +struct generateTestMatrix; + +// for real matrices, make sure none of the eigenvalues are negative +template +struct generateTestMatrix +{ + static void run(MatrixType& result, typename MatrixType::Index size) + { + MatrixType mat = MatrixType::Random(size, size); + EigenSolver es(mat); + typename EigenSolver::EigenvalueType eivals = es.eigenvalues(); + for (typename MatrixType::Index i = 0; i < size; ++i) { + if (eivals(i).imag() == 0 && eivals(i).real() < 0) + eivals(i) = -eivals(i); + } + result = (es.eigenvectors() * eivals.asDiagonal() * es.eigenvectors().inverse()).real(); + } +}; + +// for complex matrices, any matrix is fine +template +struct generateTestMatrix +{ + static void run(MatrixType& result, typename MatrixType::Index size) + { + result = MatrixType::Random(size, size); + } +}; + template void testMatrixSqrt(const MatrixType& m) { - typedef typename MatrixType::Index Index; - const Index size = m.rows(); - MatrixType A = MatrixType::Random(size, size); + MatrixType A; + generateTestMatrix::run(A, m.rows()); MatrixSquareRoot msr(A); - MatrixType S; - msr.compute(S); - VERIFY_IS_APPROX(S*S, A); + MatrixType sqrtA; + msr.compute(sqrtA); + VERIFY_IS_APPROX(sqrtA * sqrtA, A); } void test_matrix_square_root() @@ -42,5 +71,7 @@ void test_matrix_square_root() for (int i = 0; i < g_repeat; i++) { CALL_SUBTEST_1(testMatrixSqrt(Matrix3cf())); CALL_SUBTEST_2(testMatrixSqrt(MatrixXcd(12,12))); + CALL_SUBTEST_3(testMatrixSqrt(Matrix4f())); + CALL_SUBTEST_4(testMatrixSqrt(Matrix(9, 9))); } } -- cgit v1.2.3