aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--Eigen/src/Core/SolveTriangular.h16
-rw-r--r--test/cholesky.cpp68
2 files changed, 73 insertions, 11 deletions
diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h
index c25317989..960da31f3 100644
--- a/Eigen/src/Core/SolveTriangular.h
+++ b/Eigen/src/Core/SolveTriangular.h
@@ -53,7 +53,8 @@ struct ei_triangular_solver_selector;
template<typename Lhs, typename Rhs, int Mode>
struct ei_triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,RowMajor,1>
{
- typedef typename Rhs::Scalar Scalar;
+ typedef typename Lhs::Scalar LhsScalar;
+ typedef typename Rhs::Scalar RhsScalar;
typedef ei_blas_traits<Lhs> LhsProductTraits;
typedef typename LhsProductTraits::ExtractType ActualLhsType;
typedef typename Lhs::Index Index;
@@ -81,12 +82,12 @@ struct ei_triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,RowMajor
Index startRow = IsLower ? pi : pi-actualPanelWidth;
Index startCol = IsLower ? 0 : pi;
- ei_general_matrix_vector_product<Index,Scalar,RowMajor,LhsProductTraits::NeedToConjugate,Scalar,false>::run(
+ ei_general_matrix_vector_product<Index,LhsScalar,RowMajor,LhsProductTraits::NeedToConjugate,RhsScalar,false>::run(
actualPanelWidth, r,
&(actualLhs.const_cast_derived().coeffRef(startRow,startCol)), actualLhs.outerStride(),
&(other.coeffRef(startCol)), other.innerStride(),
&other.coeffRef(startRow), other.innerStride(),
- Scalar(-1));
+ RhsScalar(-1));
}
for(Index k=0; k<actualPanelWidth; ++k)
@@ -107,13 +108,12 @@ struct ei_triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,RowMajor
template<typename Lhs, typename Rhs, int Mode>
struct ei_triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,ColMajor,1>
{
- typedef typename Rhs::Scalar Scalar;
- typedef typename ei_packet_traits<Scalar>::type Packet;
+ typedef typename Lhs::Scalar LhsScalar;
+ typedef typename Rhs::Scalar RhsScalar;
typedef ei_blas_traits<Lhs> LhsProductTraits;
typedef typename LhsProductTraits::ExtractType ActualLhsType;
typedef typename Lhs::Index Index;
enum {
- PacketSize = ei_packet_traits<Scalar>::size,
IsLower = ((Mode&Lower)==Lower)
};
@@ -148,11 +148,11 @@ struct ei_triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,ColMajor
// let's directly call the low level product function because:
// 1 - it is faster to compile
// 2 - it is slighlty faster at runtime
- ei_general_matrix_vector_product<Index,Scalar,ColMajor,LhsProductTraits::NeedToConjugate,Scalar,false>::run(
+ ei_general_matrix_vector_product<Index,LhsScalar,ColMajor,LhsProductTraits::NeedToConjugate,RhsScalar,false>::run(
r, actualPanelWidth,
&(actualLhs.const_cast_derived().coeffRef(endBlock,startBlock)), actualLhs.outerStride(),
&other.coeff(startBlock), other.innerStride(),
- &(other.coeffRef(endBlock, 0)), other.innerStride(), Scalar(-1));
+ &(other.coeffRef(endBlock, 0)), other.innerStride(), RhsScalar(-1));
}
}
}
diff --git a/test/cholesky.cpp b/test/cholesky.cpp
index 136c69266..0edf9a793 100644
--- a/test/cholesky.cpp
+++ b/test/cholesky.cpp
@@ -170,6 +170,66 @@ template<typename MatrixType> void cholesky(const MatrixType& m)
}
+template<typename MatrixType> void cholesky_cplx(const MatrixType& m)
+{
+ // classic test
+ cholesky(m);
+
+ // test mixing real/scalar types
+
+ typedef typename MatrixType::Index Index;
+
+ Index rows = m.rows();
+ Index cols = m.cols();
+
+ typedef typename MatrixType::Scalar Scalar;
+ typedef typename NumTraits<Scalar>::Real RealScalar;
+ typedef Matrix<RealScalar, MatrixType::RowsAtCompileTime, MatrixType::RowsAtCompileTime> RealMatrixType;
+ typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, 1> VectorType;
+
+ RealMatrixType a0 = RealMatrixType::Random(rows,cols);
+ VectorType vecB = VectorType::Random(rows), vecX(rows);
+ MatrixType matB = MatrixType::Random(rows,cols), matX(rows,cols);
+ RealMatrixType symm = a0 * a0.adjoint();
+ // let's make sure the matrix is not singular or near singular
+ for (int k=0; k<3; ++k)
+ {
+ RealMatrixType a1 = RealMatrixType::Random(rows,cols);
+ symm += a1 * a1.adjoint();
+ }
+
+ {
+ RealMatrixType symmLo = symm.template triangularView<Lower>();
+
+ LLT<RealMatrixType,Lower> chollo(symmLo);
+ VERIFY_IS_APPROX(symm, chollo.reconstructedMatrix());
+ vecX = chollo.solve(vecB);
+ VERIFY_IS_APPROX(symm * vecX, vecB);
+// matX = chollo.solve(matB);
+// VERIFY_IS_APPROX(symm * matX, matB);
+ }
+
+ // LDLT
+ {
+ int sign = ei_random<int>()%2 ? 1 : -1;
+
+ if(sign == -1)
+ {
+ symm = -symm; // test a negative matrix
+ }
+
+ RealMatrixType symmLo = symm.template triangularView<Lower>();
+
+ LDLT<RealMatrixType,Lower> ldltlo(symmLo);
+ VERIFY_IS_APPROX(symm, ldltlo.reconstructedMatrix());
+ vecX = ldltlo.solve(vecB);
+ VERIFY_IS_APPROX(symm * vecX, vecB);
+// matX = ldltlo.solve(matB);
+// VERIFY_IS_APPROX(symm * matX, matB);
+ }
+
+}
+
template<typename MatrixType> void cholesky_verify_assert()
{
MatrixType tmp;
@@ -192,14 +252,16 @@ template<typename MatrixType> void cholesky_verify_assert()
void test_cholesky()
{
+ int s;
for(int i = 0; i < g_repeat; i++) {
CALL_SUBTEST_1( cholesky(Matrix<double,1,1>()) );
- CALL_SUBTEST_2( cholesky(MatrixXd(1,1)) );
CALL_SUBTEST_3( cholesky(Matrix2d()) );
CALL_SUBTEST_4( cholesky(Matrix3f()) );
CALL_SUBTEST_5( cholesky(Matrix4d()) );
- CALL_SUBTEST_2( cholesky(MatrixXd(200,200)) );
- CALL_SUBTEST_6( cholesky(MatrixXcd(100,100)) );
+ s = ei_random<int>(1,200);
+ CALL_SUBTEST_2( cholesky(MatrixXd(s,s)) );
+ s = ei_random<int>(1,100);
+ CALL_SUBTEST_6( cholesky_cplx(MatrixXcd(s,s)) );
}
CALL_SUBTEST_4( cholesky_verify_assert<Matrix3f>() );