diff options
author | Patrick Peltzer <peltzer@stce.rwth-aachen.de> | 2019-01-17 01:17:39 +0100 |
---|---|---|
committer | Patrick Peltzer <peltzer@stce.rwth-aachen.de> | 2019-01-17 01:17:39 +0100 |
commit | 15e53d5d93bd79fa415416d3f979975f0014a64d (patch) | |
tree | ccc062d964f707c9c1c250965490d87fbc145885 /Eigen/src/QR/CompleteOrthogonalDecomposition.h | |
parent | 7f32109c11b9cbc3cedc72e59683bf5839d35d75 (diff) |
PR 567: makes all dense solvers inherit SoverBase (LU,Cholesky,QR,SVD).
This changeset also includes:
* add HouseholderSequence::conjugateIf
* define int as the StorageIndex type for all dense solvers
* dedicated unit tests, including assertion checking
* _check_solve_assertion(): this method can be implemented in derived solver classes to implement custom checks
* CompleteOrthogonalDecompositions: add applyZOnTheLeftInPlace, fix scalar type in applyZAdjointOnTheLeftInPlace(), add missing assertions
* Cholesky: add missing assertions
* FullPivHouseholderQR: Corrected Scalar type in _solve_impl()
* BDCSVD: Unambiguous return type for ternary operator
* SVDBase: Corrected Scalar type in _solve_impl()
Diffstat (limited to 'Eigen/src/QR/CompleteOrthogonalDecomposition.h')
-rw-r--r-- | Eigen/src/QR/CompleteOrthogonalDecomposition.h | 105 |
1 files changed, 86 insertions, 19 deletions
diff --git a/Eigen/src/QR/CompleteOrthogonalDecomposition.h b/Eigen/src/QR/CompleteOrthogonalDecomposition.h index 03017a375..d62628087 100644 --- a/Eigen/src/QR/CompleteOrthogonalDecomposition.h +++ b/Eigen/src/QR/CompleteOrthogonalDecomposition.h @@ -16,6 +16,9 @@ namespace internal { template <typename _MatrixType> struct traits<CompleteOrthogonalDecomposition<_MatrixType> > : traits<_MatrixType> { + typedef MatrixXpr XprKind; + typedef SolverStorage StorageKind; + typedef int StorageIndex; enum { Flags = 0 }; }; @@ -44,19 +47,21 @@ struct traits<CompleteOrthogonalDecomposition<_MatrixType> > * * \sa MatrixBase::completeOrthogonalDecomposition() */ -template <typename _MatrixType> -class CompleteOrthogonalDecomposition { +template <typename _MatrixType> class CompleteOrthogonalDecomposition + : public SolverBase<CompleteOrthogonalDecomposition<_MatrixType> > +{ public: typedef _MatrixType MatrixType; + typedef SolverBase<CompleteOrthogonalDecomposition> Base; + + template<typename Derived> + friend struct internal::solve_assertion; + + EIGEN_GENERIC_PUBLIC_INTERFACE(CompleteOrthogonalDecomposition) enum { - RowsAtCompileTime = MatrixType::RowsAtCompileTime, - ColsAtCompileTime = MatrixType::ColsAtCompileTime, MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime, MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime }; - typedef typename MatrixType::Scalar Scalar; - typedef typename MatrixType::RealScalar RealScalar; - typedef typename MatrixType::StorageIndex StorageIndex; typedef typename internal::plain_diag_type<MatrixType>::type HCoeffsType; typedef PermutationMatrix<ColsAtCompileTime, MaxColsAtCompileTime> PermutationType; @@ -131,9 +136,9 @@ class CompleteOrthogonalDecomposition { m_temp(matrix.cols()) { computeInPlace(); - } - + } + #ifdef EIGEN_PARSED_BY_DOXYGEN /** This method computes the minimum-norm solution X to a least squares * problem \f[\mathrm{minimize} \|A X - B\|, \f] where \b A is the matrix of * which \c *this is the complete orthogonal decomposition. @@ -145,11 +150,8 @@ class CompleteOrthogonalDecomposition { */ template <typename Rhs> inline const Solve<CompleteOrthogonalDecomposition, Rhs> solve( - const MatrixBase<Rhs>& b) const { - eigen_assert(m_cpqr.m_isInitialized && - "CompleteOrthogonalDecomposition is not initialized."); - return Solve<CompleteOrthogonalDecomposition, Rhs>(*this, b.derived()); - } + const MatrixBase<Rhs>& b) const; + #endif HouseholderSequenceType householderQ(void) const; HouseholderSequenceType matrixQ(void) const { return m_cpqr.householderQ(); } @@ -158,8 +160,8 @@ class CompleteOrthogonalDecomposition { */ MatrixType matrixZ() const { MatrixType Z = MatrixType::Identity(m_cpqr.cols(), m_cpqr.cols()); - applyZAdjointOnTheLeftInPlace(Z); - return Z.adjoint(); + applyZOnTheLeftInPlace<false>(Z); + return Z; } /** \returns a reference to the matrix where the complete orthogonal @@ -275,6 +277,7 @@ class CompleteOrthogonalDecomposition { */ inline const Inverse<CompleteOrthogonalDecomposition> pseudoInverse() const { + eigen_assert(m_cpqr.m_isInitialized && "CompleteOrthogonalDecomposition is not initialized."); return Inverse<CompleteOrthogonalDecomposition>(*this); } @@ -368,6 +371,9 @@ class CompleteOrthogonalDecomposition { #ifndef EIGEN_PARSED_BY_DOXYGEN template <typename RhsType, typename DstType> void _solve_impl(const RhsType& rhs, DstType& dst) const; + + template<bool Conjugate, typename RhsType, typename DstType> + void _solve_impl_transposed(const RhsType &rhs, DstType &dst) const; #endif protected: @@ -375,8 +381,21 @@ class CompleteOrthogonalDecomposition { EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar); } + template<bool Transpose_, typename Rhs> + void _check_solve_assertion(const Rhs& b) const { + eigen_assert(m_cpqr.m_isInitialized && "CompleteOrthogonalDecomposition is not initialized."); + eigen_assert((Transpose_?derived().cols():derived().rows())==b.rows() && "CompleteOrthogonalDecomposition::solve(): invalid number of rows of the right hand side matrix b"); + } + void computeInPlace(); + /** Overwrites \b rhs with \f$ \mathbf{Z} * \mathbf{rhs} \f$ or + * \f$ \mathbf{\overline Z} * \mathbf{rhs} \f$ if \c Conjugate + * is set to \c true. + */ + template <bool Conjugate, typename Rhs> + void applyZOnTheLeftInPlace(Rhs& rhs) const; + /** Overwrites \b rhs with \f$ \mathbf{Z}^* * \mathbf{rhs} \f$. */ template <typename Rhs> @@ -465,13 +484,35 @@ void CompleteOrthogonalDecomposition<MatrixType>::computeInPlace() } template <typename MatrixType> +template <bool Conjugate, typename Rhs> +void CompleteOrthogonalDecomposition<MatrixType>::applyZOnTheLeftInPlace( + Rhs& rhs) const { + const Index cols = this->cols(); + const Index nrhs = rhs.cols(); + const Index rank = this->rank(); + Matrix<typename Rhs::Scalar, Dynamic, 1> temp((std::max)(cols, nrhs)); + for (Index k = rank-1; k >= 0; --k) { + if (k != rank - 1) { + rhs.row(k).swap(rhs.row(rank - 1)); + } + rhs.middleRows(rank - 1, cols - rank + 1) + .applyHouseholderOnTheLeft( + matrixQTZ().row(k).tail(cols - rank).transpose().template conjugateIf<!Conjugate>(), zCoeffs().template conjugateIf<Conjugate>()(k), + &temp(0)); + if (k != rank - 1) { + rhs.row(k).swap(rhs.row(rank - 1)); + } + } +} + +template <typename MatrixType> template <typename Rhs> void CompleteOrthogonalDecomposition<MatrixType>::applyZAdjointOnTheLeftInPlace( Rhs& rhs) const { const Index cols = this->cols(); const Index nrhs = rhs.cols(); const Index rank = this->rank(); - Matrix<typename MatrixType::Scalar, Dynamic, 1> temp((std::max)(cols, nrhs)); + Matrix<typename Rhs::Scalar, Dynamic, 1> temp((std::max)(cols, nrhs)); for (Index k = 0; k < rank; ++k) { if (k != rank - 1) { rhs.row(k).swap(rhs.row(rank - 1)); @@ -491,8 +532,6 @@ template <typename _MatrixType> template <typename RhsType, typename DstType> void CompleteOrthogonalDecomposition<_MatrixType>::_solve_impl( const RhsType& rhs, DstType& dst) const { - eigen_assert(rhs.rows() == this->rows()); - const Index rank = this->rank(); if (rank == 0) { dst.setZero(); @@ -520,6 +559,34 @@ void CompleteOrthogonalDecomposition<_MatrixType>::_solve_impl( // Undo permutation to get x = P^{-1} * y. dst = colsPermutation() * dst; } + +template<typename _MatrixType> +template<bool Conjugate, typename RhsType, typename DstType> +void CompleteOrthogonalDecomposition<_MatrixType>::_solve_impl_transposed(const RhsType &rhs, DstType &dst) const +{ + const Index rank = this->rank(); + + if (rank == 0) { + dst.setZero(); + return; + } + + typename RhsType::PlainObject c(colsPermutation().transpose()*rhs); + + if (rank < cols()) { + applyZOnTheLeftInPlace<!Conjugate>(c); + } + + matrixT().topLeftCorner(rank, rank) + .template triangularView<Upper>() + .transpose().template conjugateIf<Conjugate>() + .solveInPlace(c.topRows(rank)); + + dst.topRows(rank) = c.topRows(rank); + dst.bottomRows(rows()-rank).setZero(); + + dst.applyOnTheLeft(householderQ().setLength(rank).template conjugateIf<!Conjugate>() ); +} #endif namespace internal { |