aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/test/matrix_square_root.cpp
diff options
context:
space:
mode:
authorGravatar Jitse Niesen <jitse@maths.leeds.ac.uk>2011-05-08 22:18:37 +0100
committerGravatar Jitse Niesen <jitse@maths.leeds.ac.uk>2011-05-08 22:18:37 +0100
commit6e1573f66a60511a00f6c7a54cf0918d23007c3e (patch)
treefb8de5e9ed83cc500e1f1094084b1169c83d53c4 /unsupported/test/matrix_square_root.cpp
parent6b4e215710dd5c12ad1fe8e820875674bdd849c8 (diff)
Implement square root for real matrices via Schur.
Diffstat (limited to 'unsupported/test/matrix_square_root.cpp')
-rw-r--r--unsupported/test/matrix_square_root.cpp43
1 files changed, 37 insertions, 6 deletions
diff --git a/unsupported/test/matrix_square_root.cpp b/unsupported/test/matrix_square_root.cpp
index cd2c6cfc4..56b86a288 100644
--- a/unsupported/test/matrix_square_root.cpp
+++ b/unsupported/test/matrix_square_root.cpp
@@ -25,16 +25,45 @@
#include "main.h"
#include <unsupported/Eigen/MatrixFunctions>
+template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex>
+struct generateTestMatrix;
+
+// for real matrices, make sure none of the eigenvalues are negative
+template <typename MatrixType>
+struct generateTestMatrix<MatrixType,0>
+{
+ static void run(MatrixType& result, typename MatrixType::Index size)
+ {
+ MatrixType mat = MatrixType::Random(size, size);
+ EigenSolver<MatrixType> es(mat);
+ typename EigenSolver<MatrixType>::EigenvalueType eivals = es.eigenvalues();
+ for (typename MatrixType::Index i = 0; i < size; ++i) {
+ if (eivals(i).imag() == 0 && eivals(i).real() < 0)
+ eivals(i) = -eivals(i);
+ }
+ result = (es.eigenvectors() * eivals.asDiagonal() * es.eigenvectors().inverse()).real();
+ }
+};
+
+// for complex matrices, any matrix is fine
+template <typename MatrixType>
+struct generateTestMatrix<MatrixType,1>
+{
+ static void run(MatrixType& result, typename MatrixType::Index size)
+ {
+ result = MatrixType::Random(size, size);
+ }
+};
+
template<typename MatrixType>
void testMatrixSqrt(const MatrixType& m)
{
- typedef typename MatrixType::Index Index;
- const Index size = m.rows();
- MatrixType A = MatrixType::Random(size, size);
+ MatrixType A;
+ generateTestMatrix<MatrixType>::run(A, m.rows());
MatrixSquareRoot<MatrixType> msr(A);
- MatrixType S;
- msr.compute(S);
- VERIFY_IS_APPROX(S*S, A);
+ MatrixType sqrtA;
+ msr.compute(sqrtA);
+ VERIFY_IS_APPROX(sqrtA * sqrtA, A);
}
void test_matrix_square_root()
@@ -42,5 +71,7 @@ void test_matrix_square_root()
for (int i = 0; i < g_repeat; i++) {
CALL_SUBTEST_1(testMatrixSqrt(Matrix3cf()));
CALL_SUBTEST_2(testMatrixSqrt(MatrixXcd(12,12)));
+ CALL_SUBTEST_3(testMatrixSqrt(Matrix4f()));
+ CALL_SUBTEST_4(testMatrixSqrt(Matrix<double,Dynamic,Dynamic,RowMajor>(9, 9)));
}
}