diff options
author | Gael Guennebaud <g.gael@free.fr> | 2016-07-04 15:13:35 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2016-07-04 15:13:35 +0200 |
commit | 32a41ee659686fe1fb76156f7a55287acf14d4bb (patch) | |
tree | 55ba624046d61d3f57f14db2044e1c30cf2d3e14 | |
parent | 75e80792cc98b09d4ba92df67ab810d9af983e87 (diff) |
bug #707: add inplace decomposition through Ref<> for Cholesky, LU and QR decompositions.
-rw-r--r-- | Eigen/src/Cholesky/LDLT.h | 21 | ||||
-rw-r--r-- | Eigen/src/Cholesky/LLT.h | 16 | ||||
-rw-r--r-- | Eigen/src/LU/FullPivLU.h | 44 | ||||
-rw-r--r-- | Eigen/src/LU/PartialPivLU.h | 62 | ||||
-rw-r--r-- | Eigen/src/QR/ColPivHouseholderQR.h | 37 | ||||
-rw-r--r-- | Eigen/src/QR/CompleteOrthogonalDecomposition.h | 47 | ||||
-rw-r--r-- | Eigen/src/QR/FullPivHouseholderQR.h | 27 | ||||
-rw-r--r-- | Eigen/src/QR/HouseholderQR.h | 36 | ||||
-rw-r--r-- | test/CMakeLists.txt | 1 | ||||
-rw-r--r-- | test/inplace_decomposition.cpp | 110 |
10 files changed, 337 insertions, 64 deletions
diff --git a/Eigen/src/Cholesky/LDLT.h b/Eigen/src/Cholesky/LDLT.h index 538aff956..a31b3d6aa 100644 --- a/Eigen/src/Cholesky/LDLT.h +++ b/Eigen/src/Cholesky/LDLT.h @@ -52,7 +52,6 @@ template<typename _MatrixType, int _UpLo> class LDLT enum { RowsAtCompileTime = MatrixType::RowsAtCompileTime, ColsAtCompileTime = MatrixType::ColsAtCompileTime, - Options = MatrixType::Options & ~RowMajorBit, // these are the options for the TmpMatrixType, we need a ColMajor matrix here! MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime, MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime, UpLo = _UpLo @@ -61,7 +60,7 @@ template<typename _MatrixType, int _UpLo> class LDLT typedef typename NumTraits<typename MatrixType::Scalar>::Real RealScalar; typedef Eigen::Index Index; ///< \deprecated since Eigen 3.3 typedef typename MatrixType::StorageIndex StorageIndex; - typedef Matrix<Scalar, RowsAtCompileTime, 1, Options, MaxRowsAtCompileTime, 1> TmpMatrixType; + typedef Matrix<Scalar, RowsAtCompileTime, 1, 0, MaxRowsAtCompileTime, 1> TmpMatrixType; typedef Transpositions<RowsAtCompileTime, MaxRowsAtCompileTime> TranspositionType; typedef PermutationMatrix<RowsAtCompileTime, MaxRowsAtCompileTime> PermutationType; @@ -97,6 +96,7 @@ template<typename _MatrixType, int _UpLo> class LDLT /** \brief Constructor with decomposition * * This calculates the decomposition for the input \a matrix. + * * \sa LDLT(Index size) */ template<typename InputType> @@ -110,6 +110,23 @@ template<typename _MatrixType, int _UpLo> class LDLT compute(matrix.derived()); } + /** \brief Constructs a LDLT factorization from a given matrix + * + * This overloaded constructor is provided for inplace solving when \c MatrixType is a Eigen::Ref. + * + * \sa LDLT(const EigenBase&) + */ + template<typename InputType> + explicit LDLT(EigenBase<InputType>& matrix) + : m_matrix(matrix.derived()), + m_transpositions(matrix.rows()), + m_temporary(matrix.rows()), + m_sign(internal::ZeroSign), + m_isInitialized(false) + { + compute(matrix.derived()); + } + /** Clear any existing decomposition * \sa rankUpdate(w,sigma) */ diff --git a/Eigen/src/Cholesky/LLT.h b/Eigen/src/Cholesky/LLT.h index 19578b216..ad163c749 100644 --- a/Eigen/src/Cholesky/LLT.h +++ b/Eigen/src/Cholesky/LLT.h @@ -54,7 +54,6 @@ template<typename _MatrixType, int _UpLo> class LLT enum { RowsAtCompileTime = MatrixType::RowsAtCompileTime, ColsAtCompileTime = MatrixType::ColsAtCompileTime, - Options = MatrixType::Options, MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime }; typedef typename MatrixType::Scalar Scalar; @@ -95,6 +94,21 @@ template<typename _MatrixType, int _UpLo> class LLT compute(matrix.derived()); } + /** \brief Constructs a LDLT factorization from a given matrix + * + * This overloaded constructor is provided for inplace solving when + * \c MatrixType is a Eigen::Ref. + * + * \sa LLT(const EigenBase&) + */ + template<typename InputType> + explicit LLT(EigenBase<InputType>& matrix) + : m_matrix(matrix.derived()), + m_isInitialized(false) + { + compute(matrix.derived()); + } + /** \returns a view of the upper triangular matrix U */ inline typename Traits::MatrixU matrixU() const { diff --git a/Eigen/src/LU/FullPivLU.h b/Eigen/src/LU/FullPivLU.h index 2d01b18c6..113b8c7b8 100644 --- a/Eigen/src/LU/FullPivLU.h +++ b/Eigen/src/LU/FullPivLU.h @@ -97,6 +97,15 @@ template<typename _MatrixType> class FullPivLU template<typename InputType> explicit FullPivLU(const EigenBase<InputType>& matrix); + /** \brief Constructs a LU factorization from a given matrix + * + * This overloaded constructor is provided for inplace solving when \c MatrixType is a Eigen::Ref. + * + * \sa FullPivLU(const EigenBase&) + */ + template<typename InputType> + explicit FullPivLU(EigenBase<InputType>& matrix); + /** Computes the LU decomposition of the given matrix. * * \param matrix the matrix of which to compute the LU decomposition. @@ -105,7 +114,11 @@ template<typename _MatrixType> class FullPivLU * \returns a reference to *this */ template<typename InputType> - FullPivLU& compute(const EigenBase<InputType>& matrix); + FullPivLU& compute(const EigenBase<InputType>& matrix) { + m_lu = matrix.derived(); + computeInPlace(); + return *this; + } /** \returns the LU decomposition matrix: the upper-triangular part is U, the * unit-lower-triangular part is L (at least for square matrices; in the non-square @@ -459,25 +472,28 @@ FullPivLU<MatrixType>::FullPivLU(const EigenBase<InputType>& matrix) template<typename MatrixType> template<typename InputType> -FullPivLU<MatrixType>& FullPivLU<MatrixType>::compute(const EigenBase<InputType>& matrix) +FullPivLU<MatrixType>::FullPivLU(EigenBase<InputType>& matrix) + : m_lu(matrix.derived()), + m_p(matrix.rows()), + m_q(matrix.cols()), + m_rowsTranspositions(matrix.rows()), + m_colsTranspositions(matrix.cols()), + m_isInitialized(false), + m_usePrescribedThreshold(false) { - check_template_parameters(); - - // the permutations are stored as int indices, so just to be sure: - eigen_assert(matrix.rows()<=NumTraits<int>::highest() && matrix.cols()<=NumTraits<int>::highest()); - - m_lu = matrix.derived(); - m_l1_norm = m_lu.cwiseAbs().colwise().sum().maxCoeff(); - computeInPlace(); - - m_isInitialized = true; - return *this; } template<typename MatrixType> void FullPivLU<MatrixType>::computeInPlace() { + check_template_parameters(); + + // the permutations are stored as int indices, so just to be sure: + eigen_assert(m_lu.rows()<=NumTraits<int>::highest() && m_lu.cols()<=NumTraits<int>::highest()); + + m_l1_norm = m_lu.cwiseAbs().colwise().sum().maxCoeff(); + const Index size = m_lu.diagonalSize(); const Index rows = m_lu.rows(); const Index cols = m_lu.cols(); @@ -557,6 +573,8 @@ void FullPivLU<MatrixType>::computeInPlace() m_q.applyTranspositionOnTheRight(k, m_colsTranspositions.coeff(k)); m_det_pq = (number_of_transpositions%2) ? -1 : 1; + + m_isInitialized = true; } template<typename MatrixType> diff --git a/Eigen/src/LU/PartialPivLU.h b/Eigen/src/LU/PartialPivLU.h index ac2902261..c862d9692 100644 --- a/Eigen/src/LU/PartialPivLU.h +++ b/Eigen/src/LU/PartialPivLU.h @@ -26,6 +26,17 @@ template<typename _MatrixType> struct traits<PartialPivLU<_MatrixType> > }; }; +template<typename T,typename Derived> +struct enable_if_ref; +// { +// typedef Derived type; +// }; + +template<typename T,typename Derived> +struct enable_if_ref<Ref<T>,Derived> { + typedef Derived type; +}; + } // end namespace internal /** \ingroup LU_Module @@ -102,8 +113,29 @@ template<typename _MatrixType> class PartialPivLU template<typename InputType> explicit PartialPivLU(const EigenBase<InputType>& matrix); + /** Constructor for inplace decomposition + * + * \param matrix the matrix of which to compute the LU decomposition. + * + * If \c MatrixType is an Eigen::Ref, then the storage of \a matrix will be shared + * between \a matrix and \c *this and the decomposition will take place in-place. + * The memory of \a matrix will be used througrough the lifetime of \c *this. In + * particular, further calls to \c this->compute(A) will still operate on the memory + * of \a matrix meaning. This also implies that the sizes of \c A must match the + * ones of \a matrix. + * + * \warning The matrix should have full rank (e.g. if it's square, it should be invertible). + * If you need to deal with non-full rank, use class FullPivLU instead. + */ + template<typename InputType> + explicit PartialPivLU(EigenBase<InputType>& matrix); + template<typename InputType> - PartialPivLU& compute(const EigenBase<InputType>& matrix); + PartialPivLU& compute(const EigenBase<InputType>& matrix) { + m_lu = matrix.derived(); + compute(); + return *this; + } /** \returns the LU decomposition matrix: the upper-triangular part is U, the * unit-lower-triangular part is L (at least for square matrices; in the non-square @@ -251,6 +283,8 @@ template<typename _MatrixType> class PartialPivLU EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar); } + void compute(); + MatrixType m_lu; PermutationType m_p; TranspositionType m_rowsTranspositions; @@ -284,7 +318,7 @@ PartialPivLU<MatrixType>::PartialPivLU(Index size) template<typename MatrixType> template<typename InputType> PartialPivLU<MatrixType>::PartialPivLU(const EigenBase<InputType>& matrix) - : m_lu(matrix.rows(), matrix.rows()), + : m_lu(matrix.rows(),matrix.cols()), m_p(matrix.rows()), m_rowsTranspositions(matrix.rows()), m_l1_norm(0), @@ -294,6 +328,19 @@ PartialPivLU<MatrixType>::PartialPivLU(const EigenBase<InputType>& matrix) compute(matrix.derived()); } +template<typename MatrixType> +template<typename InputType> +PartialPivLU<MatrixType>::PartialPivLU(EigenBase<InputType>& matrix) + : m_lu(matrix.derived()), + m_p(matrix.rows()), + m_rowsTranspositions(matrix.rows()), + m_l1_norm(0), + m_det_p(0), + m_isInitialized(false) +{ + compute(); +} + namespace internal { /** \internal This is the blocked version of fullpivlu_unblocked() */ @@ -470,19 +517,17 @@ void partial_lu_inplace(MatrixType& lu, TranspositionType& row_transpositions, t } // end namespace internal template<typename MatrixType> -template<typename InputType> -PartialPivLU<MatrixType>& PartialPivLU<MatrixType>::compute(const EigenBase<InputType>& matrix) +void PartialPivLU<MatrixType>::compute() { check_template_parameters(); // the row permutation is stored as int indices, so just to be sure: - eigen_assert(matrix.rows()<NumTraits<int>::highest()); + eigen_assert(m_lu.rows()<NumTraits<int>::highest()); - m_lu = matrix.derived(); m_l1_norm = m_lu.cwiseAbs().colwise().sum().maxCoeff(); - eigen_assert(matrix.rows() == matrix.cols() && "PartialPivLU is only for square (and moreover invertible) matrices"); - const Index size = matrix.rows(); + eigen_assert(m_lu.rows() == m_lu.cols() && "PartialPivLU is only for square (and moreover invertible) matrices"); + const Index size = m_lu.rows(); m_rowsTranspositions.resize(size); @@ -493,7 +538,6 @@ PartialPivLU<MatrixType>& PartialPivLU<MatrixType>::compute(const EigenBase<Inpu m_p = m_rowsTranspositions; m_isInitialized = true; - return *this; } template<typename MatrixType> diff --git a/Eigen/src/QR/ColPivHouseholderQR.h b/Eigen/src/QR/ColPivHouseholderQR.h index e847bc434..db50b5675 100644 --- a/Eigen/src/QR/ColPivHouseholderQR.h +++ b/Eigen/src/QR/ColPivHouseholderQR.h @@ -51,7 +51,6 @@ template<typename _MatrixType> class ColPivHouseholderQR enum { RowsAtCompileTime = MatrixType::RowsAtCompileTime, ColsAtCompileTime = MatrixType::ColsAtCompileTime, - Options = MatrixType::Options, MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime, MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime }; @@ -59,7 +58,6 @@ template<typename _MatrixType> class ColPivHouseholderQR typedef typename MatrixType::RealScalar RealScalar; // FIXME should be int typedef typename MatrixType::StorageIndex StorageIndex; - typedef Matrix<Scalar, RowsAtCompileTime, RowsAtCompileTime, Options, MaxRowsAtCompileTime, MaxRowsAtCompileTime> MatrixQType; typedef typename internal::plain_diag_type<MatrixType>::type HCoeffsType; typedef PermutationMatrix<ColsAtCompileTime, MaxColsAtCompileTime> PermutationType; typedef typename internal::plain_row_type<MatrixType, Index>::type IntRowVectorType; @@ -135,6 +133,27 @@ template<typename _MatrixType> class ColPivHouseholderQR compute(matrix.derived()); } + /** \brief Constructs a QR factorization from a given matrix + * + * This overloaded constructor is provided for inplace solving when \c MatrixType is a Eigen::Ref. + * + * \sa ColPivHouseholderQR(const EigenBase&) + */ + template<typename InputType> + explicit ColPivHouseholderQR(EigenBase<InputType>& matrix) + : m_qr(matrix.derived()), + m_hCoeffs((std::min)(matrix.rows(),matrix.cols())), + m_colsPermutation(PermIndexType(matrix.cols())), + m_colsTranspositions(matrix.cols()), + m_temp(matrix.cols()), + m_colNormsUpdated(matrix.cols()), + m_colNormsDirect(matrix.cols()), + m_isInitialized(false), + m_usePrescribedThreshold(false) + { + computeInPlace(); + } + /** This method finds a solution x to the equation Ax=b, where A is the matrix of which * *this is the QR decomposition, if any exists. * @@ -453,21 +472,19 @@ template<typename MatrixType> template<typename InputType> ColPivHouseholderQR<MatrixType>& ColPivHouseholderQR<MatrixType>::compute(const EigenBase<InputType>& matrix) { - check_template_parameters(); - - // the column permutation is stored as int indices, so just to be sure: - eigen_assert(matrix.cols()<=NumTraits<int>::highest()); - - m_qr = matrix; - + m_qr = matrix.derived(); computeInPlace(); - return *this; } template<typename MatrixType> void ColPivHouseholderQR<MatrixType>::computeInPlace() { + check_template_parameters(); + + // the column permutation is stored as int indices, so just to be sure: + eigen_assert(m_qr.cols()<=NumTraits<int>::highest()); + using std::abs; Index rows = m_qr.rows(); diff --git a/Eigen/src/QR/CompleteOrthogonalDecomposition.h b/Eigen/src/QR/CompleteOrthogonalDecomposition.h index 398e1aa77..967eb35dd 100644 --- a/Eigen/src/QR/CompleteOrthogonalDecomposition.h +++ b/Eigen/src/QR/CompleteOrthogonalDecomposition.h @@ -48,16 +48,12 @@ class CompleteOrthogonalDecomposition { enum { RowsAtCompileTime = MatrixType::RowsAtCompileTime, ColsAtCompileTime = MatrixType::ColsAtCompileTime, - Options = MatrixType::Options, MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime, MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime }; typedef typename MatrixType::Scalar Scalar; typedef typename MatrixType::RealScalar RealScalar; typedef typename MatrixType::StorageIndex StorageIndex; - typedef Matrix<Scalar, RowsAtCompileTime, RowsAtCompileTime, Options, - MaxRowsAtCompileTime, MaxRowsAtCompileTime> - MatrixQType; typedef typename internal::plain_diag_type<MatrixType>::type HCoeffsType; typedef PermutationMatrix<ColsAtCompileTime, MaxColsAtCompileTime> PermutationType; @@ -114,10 +110,27 @@ class CompleteOrthogonalDecomposition { explicit CompleteOrthogonalDecomposition(const EigenBase<InputType>& matrix) : m_cpqr(matrix.rows(), matrix.cols()), m_zCoeffs((std::min)(matrix.rows(), matrix.cols())), - m_temp(matrix.cols()) { + m_temp(matrix.cols()) + { compute(matrix.derived()); } + /** \brief Constructs a complete orthogonal decomposition from a given matrix + * + * This overloaded constructor is provided for inplace solving when \c MatrixType is a Eigen::Ref. + * + * \sa CompleteOrthogonalDecomposition(const EigenBase&) + */ + template<typename InputType> + explicit CompleteOrthogonalDecomposition(EigenBase<InputType>& matrix) + : m_cpqr(matrix.derived()), + m_zCoeffs((std::min)(matrix.rows(), matrix.cols())), + m_temp(matrix.cols()) + { + computeInPlace(); + } + + /** This method computes the minimum-norm solution X to a least squares * problem \f[\mathrm{minimize} ||A X - B|| \f], where \b A is the matrix of * which \c *this is the complete orthogonal decomposition. @@ -165,7 +178,12 @@ class CompleteOrthogonalDecomposition { const MatrixType& matrixT() const { return m_cpqr.matrixQR(); } template <typename InputType> - CompleteOrthogonalDecomposition& compute(const EigenBase<InputType>& matrix); + CompleteOrthogonalDecomposition& compute(const EigenBase<InputType>& matrix) { + // Compute the column pivoted QR factorization A P = Q R. + m_cpqr.compute(matrix); + computeInPlace(); + return *this; + } /** \returns a const reference to the column permutation matrix */ const PermutationType& colsPermutation() const { @@ -354,6 +372,8 @@ class CompleteOrthogonalDecomposition { EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar); } + void computeInPlace(); + /** Overwrites \b rhs with \f$ \mathbf{Z}^* * \mathbf{rhs} \f$. */ template <typename Rhs> @@ -384,20 +404,16 @@ CompleteOrthogonalDecomposition<MatrixType>::logAbsDeterminant() const { * CompleteOrthogonalDecomposition(const MatrixType&) */ template <typename MatrixType> -template <typename InputType> -CompleteOrthogonalDecomposition<MatrixType>& CompleteOrthogonalDecomposition< - MatrixType>::compute(const EigenBase<InputType>& matrix) { +void CompleteOrthogonalDecomposition<MatrixType>::computeInPlace() +{ check_template_parameters(); // the column permutation is stored as int indices, so just to be sure: - eigen_assert(matrix.cols() <= NumTraits<int>::highest()); - - // Compute the column pivoted QR factorization A P = Q R. - m_cpqr.compute(matrix); + eigen_assert(m_cpqr.cols() <= NumTraits<int>::highest()); const Index rank = m_cpqr.rank(); - const Index cols = matrix.cols(); - const Index rows = matrix.rows(); + const Index cols = m_cpqr.cols(); + const Index rows = m_cpqr.rows(); m_zCoeffs.resize((std::min)(rows, cols)); m_temp.resize(cols); @@ -443,7 +459,6 @@ CompleteOrthogonalDecomposition<MatrixType>& CompleteOrthogonalDecomposition< } } } - return *this; } template <typename MatrixType> diff --git a/Eigen/src/QR/FullPivHouseholderQR.h b/Eigen/src/QR/FullPivHouseholderQR.h index e21966056..5c2f57d04 100644 --- a/Eigen/src/QR/FullPivHouseholderQR.h +++ b/Eigen/src/QR/FullPivHouseholderQR.h @@ -60,7 +60,6 @@ template<typename _MatrixType> class FullPivHouseholderQR enum { RowsAtCompileTime = MatrixType::RowsAtCompileTime, ColsAtCompileTime = MatrixType::ColsAtCompileTime, - Options = MatrixType::Options, MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime, MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime }; @@ -135,6 +134,26 @@ template<typename _MatrixType> class FullPivHouseholderQR compute(matrix.derived()); } + /** \brief Constructs a QR factorization from a given matrix + * + * This overloaded constructor is provided for inplace solving when \c MatrixType is a Eigen::Ref. + * + * \sa FullPivHouseholderQR(const EigenBase&) + */ + template<typename InputType> + explicit FullPivHouseholderQR(EigenBase<InputType>& matrix) + : m_qr(matrix.derived()), + m_hCoeffs((std::min)(matrix.rows(), matrix.cols())), + m_rows_transpositions((std::min)(matrix.rows(), matrix.cols())), + m_cols_transpositions((std::min)(matrix.rows(), matrix.cols())), + m_cols_permutation(matrix.cols()), + m_temp(matrix.cols()), + m_isInitialized(false), + m_usePrescribedThreshold(false) + { + computeInPlace(); + } + /** This method finds a solution x to the equation Ax=b, where A is the matrix of which * \c *this is the QR decomposition. * @@ -430,18 +449,16 @@ template<typename MatrixType> template<typename InputType> FullPivHouseholderQR<MatrixType>& FullPivHouseholderQR<MatrixType>::compute(const EigenBase<InputType>& matrix) { - check_template_parameters(); - m_qr = matrix.derived(); - computeInPlace(); - return *this; } template<typename MatrixType> void FullPivHouseholderQR<MatrixType>::computeInPlace() { + check_template_parameters(); + using std::abs; Index rows = m_qr.rows(); Index cols = m_qr.cols(); diff --git a/Eigen/src/QR/HouseholderQR.h b/Eigen/src/QR/HouseholderQR.h index 03bc8e6cd..f2a9cc080 100644 --- a/Eigen/src/QR/HouseholderQR.h +++ b/Eigen/src/QR/HouseholderQR.h @@ -47,7 +47,6 @@ template<typename _MatrixType> class HouseholderQR enum { RowsAtCompileTime = MatrixType::RowsAtCompileTime, ColsAtCompileTime = MatrixType::ColsAtCompileTime, - Options = MatrixType::Options, MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime, MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime }; @@ -102,6 +101,24 @@ template<typename _MatrixType> class HouseholderQR compute(matrix.derived()); } + + /** \brief Constructs a QR factorization from a given matrix + * + * This overloaded constructor is provided for inplace solving when + * \c MatrixType is a Eigen::Ref. + * + * \sa HouseholderQR(const EigenBase&) + */ + template<typename InputType> + explicit HouseholderQR(EigenBase<InputType>& matrix) + : m_qr(matrix.derived()), + m_hCoeffs((std::min)(matrix.rows(),matrix.cols())), + m_temp(matrix.cols()), + m_isInitialized(false) + { + computeInPlace(); + } + /** This method finds a solution x to the equation Ax=b, where A is the matrix of which * *this is the QR decomposition, if any exists. * @@ -151,7 +168,11 @@ template<typename _MatrixType> class HouseholderQR } template<typename InputType> - HouseholderQR& compute(const EigenBase<InputType>& matrix); + HouseholderQR& compute(const EigenBase<InputType>& matrix) { + m_qr = matrix.derived(); + computeInPlace(); + return *this; + } /** \returns the absolute value of the determinant of the matrix of which * *this is the QR decomposition. It has only linear complexity @@ -203,6 +224,8 @@ template<typename _MatrixType> class HouseholderQR { EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar); } + + void computeInPlace(); MatrixType m_qr; HCoeffsType m_hCoeffs; @@ -354,16 +377,14 @@ void HouseholderQR<_MatrixType>::_solve_impl(const RhsType &rhs, DstType &dst) c * \sa class HouseholderQR, HouseholderQR(const MatrixType&) */ template<typename MatrixType> -template<typename InputType> -HouseholderQR<MatrixType>& HouseholderQR<MatrixType>::compute(const EigenBase<InputType>& matrix) +void HouseholderQR<MatrixType>::computeInPlace() { check_template_parameters(); - Index rows = matrix.rows(); - Index cols = matrix.cols(); + Index rows = m_qr.rows(); + Index cols = m_qr.cols(); Index size = (std::min)(rows,cols); - m_qr = matrix.derived(); m_hCoeffs.resize(size); m_temp.resize(cols); @@ -371,7 +392,6 @@ HouseholderQR<MatrixType>& HouseholderQR<MatrixType>::compute(const EigenBase<In internal::householder_qr_inplace_blocked<MatrixType, HCoeffsType>::run(m_qr, m_hCoeffs, 48, m_temp.data()); m_isInitialized = true; - return *this; } #ifndef __CUDACC__ diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 9d49f1e97..27a9ec4ac 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -258,6 +258,7 @@ ei_add_test(rvalue_types) ei_add_test(dense_storage) ei_add_test(ctorleak) ei_add_test(mpl2only) +ei_add_test(inplace_decomposition) add_executable(bug1213 bug1213.cpp bug1213_main.cpp) diff --git a/test/inplace_decomposition.cpp b/test/inplace_decomposition.cpp new file mode 100644 index 000000000..4900312f1 --- /dev/null +++ b/test/inplace_decomposition.cpp @@ -0,0 +1,110 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2016 Gael Guennebaud <gael.guennebaud@inria.fr> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "main.h" +#include <Eigen/LU> +#include <Eigen/Cholesky> +#include <Eigen/QR> + +// This file test inplace decomposition through Ref<>, as supported by Cholesky, LU, and QR decompositions. + +template<typename DecType,typename MatrixType> void inplace(bool square = false, bool SPD = false) +{ + typedef typename MatrixType::Scalar Scalar; + typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, 1> RhsType; + typedef Matrix<Scalar, MatrixType::ColsAtCompileTime, 1> ResType; + + Index rows = MatrixType::RowsAtCompileTime==Dynamic ? internal::random<Index>(2,EIGEN_TEST_MAX_SIZE/2) : MatrixType::RowsAtCompileTime; + Index cols = MatrixType::ColsAtCompileTime==Dynamic ? (square?rows:internal::random<Index>(2,rows)) : MatrixType::ColsAtCompileTime; + + MatrixType A = MatrixType::Random(rows,cols); + RhsType b = RhsType::Random(rows); + ResType x(cols); + + if(SPD) + { + assert(square); + A.topRows(cols) = A.topRows(cols).adjoint() * A.topRows(cols); + A.diagonal().array() += 1e-3; + } + + MatrixType A0 = A; + MatrixType A1 = A; + + DecType dec(A); + + // Check that the content of A has been modified + VERIFY_IS_NOT_APPROX( A, A0 ); + + // Check that the decomposition is correct: + if(rows==cols) + { + VERIFY_IS_APPROX( A0 * (x = dec.solve(b)), b ); + } + else + { + VERIFY_IS_APPROX( A0.transpose() * A0 * (x = dec.solve(b)), A0.transpose() * b ); + } + + // Check that modifying A breaks the current dec: + A.setRandom(); + if(rows==cols) + { + VERIFY_IS_NOT_APPROX( A0 * (x = dec.solve(b)), b ); + } + else + { + VERIFY_IS_NOT_APPROX( A0.transpose() * A0 * (x = dec.solve(b)), A0.transpose() * b ); + } + + // Check that calling compute(A1) does not modify A1: + A = A0; + dec.compute(A1); + VERIFY_IS_EQUAL(A0,A1); + VERIFY_IS_NOT_APPROX( A, A0 ); + if(rows==cols) + { + VERIFY_IS_APPROX( A0 * (x = dec.solve(b)), b ); + } + else + { + VERIFY_IS_APPROX( A0.transpose() * A0 * (x = dec.solve(b)), A0.transpose() * b ); + } +} + + +void test_inplace_decomposition() +{ + EIGEN_UNUSED typedef Matrix<double,4,3> Matrix43d; + for(int i = 0; i < g_repeat; i++) { + CALL_SUBTEST_1(( inplace<LLT<Ref<MatrixXd> >, MatrixXd>(true,true) )); + CALL_SUBTEST_1(( inplace<LLT<Ref<Matrix4d> >, Matrix4d>(true,true) )); + + CALL_SUBTEST_2(( inplace<LDLT<Ref<MatrixXd> >, MatrixXd>(true,true) )); + CALL_SUBTEST_2(( inplace<LDLT<Ref<Matrix4d> >, Matrix4d>(true,true) )); + + CALL_SUBTEST_3(( inplace<PartialPivLU<Ref<MatrixXd> >, MatrixXd>(true,false) )); + CALL_SUBTEST_3(( inplace<PartialPivLU<Ref<Matrix4d> >, Matrix4d>(true,false) )); + + CALL_SUBTEST_4(( inplace<FullPivLU<Ref<MatrixXd> >, MatrixXd>(true,false) )); + CALL_SUBTEST_4(( inplace<FullPivLU<Ref<Matrix4d> >, Matrix4d>(true,false) )); + + CALL_SUBTEST_5(( inplace<HouseholderQR<Ref<MatrixXd> >, MatrixXd>(false,false) )); + CALL_SUBTEST_5(( inplace<HouseholderQR<Ref<Matrix43d> >, Matrix43d>(false,false) )); + + CALL_SUBTEST_6(( inplace<ColPivHouseholderQR<Ref<MatrixXd> >, MatrixXd>(false,false) )); + CALL_SUBTEST_6(( inplace<ColPivHouseholderQR<Ref<Matrix43d> >, Matrix43d>(false,false) )); + + CALL_SUBTEST_7(( inplace<FullPivHouseholderQR<Ref<MatrixXd> >, MatrixXd>(false,false) )); + CALL_SUBTEST_7(( inplace<FullPivHouseholderQR<Ref<Matrix43d> >, Matrix43d>(false,false) )); + + CALL_SUBTEST_8(( inplace<CompleteOrthogonalDecomposition<Ref<MatrixXd> >, MatrixXd>(false,false) )); + CALL_SUBTEST_8(( inplace<CompleteOrthogonalDecomposition<Ref<Matrix43d> >, Matrix43d>(false,false) )); + } +} |