aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h18
-rw-r--r--unsupported/test/matrix_function.cpp38
2 files changed, 48 insertions, 8 deletions
diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h b/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h
index 34bf78913..e363e779d 100644
--- a/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h
+++ b/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h
@@ -253,18 +253,19 @@ struct matrix_sqrt_compute
template <typename MatrixType>
struct matrix_sqrt_compute<MatrixType, 0>
{
+ typedef typename MatrixType::PlainObject PlainType;
template <typename ResultType>
static void run(const MatrixType &arg, ResultType &result)
{
eigen_assert(arg.rows() == arg.cols());
// Compute Schur decomposition of arg
- const RealSchur<MatrixType> schurOfA(arg);
- const MatrixType& T = schurOfA.matrixT();
- const MatrixType& U = schurOfA.matrixU();
+ const RealSchur<PlainType> schurOfA(arg);
+ const PlainType& T = schurOfA.matrixT();
+ const PlainType& U = schurOfA.matrixU();
// Compute square root of T
- MatrixType sqrtT = MatrixType::Zero(arg.rows(), arg.cols());
+ PlainType sqrtT = PlainType::Zero(arg.rows(), arg.cols());
matrix_sqrt_quasi_triangular(T, sqrtT);
// Compute square root of arg
@@ -278,18 +279,19 @@ struct matrix_sqrt_compute<MatrixType, 0>
template <typename MatrixType>
struct matrix_sqrt_compute<MatrixType, 1>
{
+ typedef typename MatrixType::PlainObject PlainType;
template <typename ResultType>
static void run(const MatrixType &arg, ResultType &result)
{
eigen_assert(arg.rows() == arg.cols());
// Compute Schur decomposition of arg
- const ComplexSchur<MatrixType> schurOfA(arg);
- const MatrixType& T = schurOfA.matrixT();
- const MatrixType& U = schurOfA.matrixU();
+ const ComplexSchur<PlainType> schurOfA(arg);
+ const PlainType& T = schurOfA.matrixT();
+ const PlainType& U = schurOfA.matrixU();
// Compute square root of T
- MatrixType sqrtT;
+ PlainType sqrtT;
matrix_sqrt_triangular(T, sqrtT);
// Compute square root of arg
diff --git a/unsupported/test/matrix_function.cpp b/unsupported/test/matrix_function.cpp
index 2049b8ba0..6d753737d 100644
--- a/unsupported/test/matrix_function.cpp
+++ b/unsupported/test/matrix_function.cpp
@@ -177,6 +177,39 @@ void testMatrixType(const MatrixType& m)
}
}
+template<typename MatrixType>
+void testMapRef(const MatrixType& A)
+{
+ // Test if passing Ref and Map objects is possible
+ // (Regression test for Bug #1796)
+ Index size = A.rows();
+ MatrixType X; X.setRandom(size, size);
+ MatrixType Y(size,size);
+ Ref< MatrixType> R(Y);
+ Ref<const MatrixType> Rc(X);
+ Map< MatrixType> M(Y.data(), size, size);
+ Map<const MatrixType> Mc(X.data(), size, size);
+
+ X = X*X; // make sure sqrt is possible
+ Y = X.sqrt();
+ R = Rc.sqrt();
+ M = Mc.sqrt();
+ Y = X.exp();
+ R = Rc.exp();
+ M = Mc.exp();
+ X = Y; // make sure log is possible
+ Y = X.log();
+ R = Rc.log();
+ M = Mc.log();
+
+ Y = X.cos() + Rc.cos() + Mc.cos();
+ Y = X.sin() + Rc.sin() + Mc.sin();
+
+ Y = X.cosh() + Rc.cosh() + Mc.cosh();
+ Y = X.sinh() + Rc.sinh() + Mc.sinh();
+}
+
+
EIGEN_DECLARE_TEST(matrix_function)
{
CALL_SUBTEST_1(testMatrixType(Matrix<float,1,1>()));
@@ -186,4 +219,9 @@ EIGEN_DECLARE_TEST(matrix_function)
CALL_SUBTEST_5(testMatrixType(Matrix<double,5,5,RowMajor>()));
CALL_SUBTEST_6(testMatrixType(Matrix4cd()));
CALL_SUBTEST_7(testMatrixType(MatrixXd(13,13)));
+
+ CALL_SUBTEST_1(testMapRef(Matrix<float,1,1>()));
+ CALL_SUBTEST_2(testMapRef(Matrix3cf()));
+ CALL_SUBTEST_3(testMapRef(MatrixXf(8,8)));
+ CALL_SUBTEST_7(testMapRef(MatrixXd(13,13)));
}