diff options
author | Desire NUENTSA <desire.nuentsa_wakam@inria.fr> | 2013-05-21 17:35:10 +0200 |
---|---|---|
committer | Desire NUENTSA <desire.nuentsa_wakam@inria.fr> | 2013-05-21 17:35:10 +0200 |
commit | cf939f154fe1ec9904276f5f3f2dbdbc8e98156a (patch) | |
tree | 7f68884a51d66a3bc1272c660f6436b5e8ecacda /Eigen/src/SparseQR | |
parent | bd7511fc3651c70d20473dc0c6beab3013fd229a (diff) |
Fix bug #596 : Recover plain SparseMatrix from SparseQR matrixQ()
Diffstat (limited to 'Eigen/src/SparseQR')
-rw-r--r-- | Eigen/src/SparseQR/SparseQR.h | 118 |
1 files changed, 93 insertions, 25 deletions
diff --git a/Eigen/src/SparseQR/SparseQR.h b/Eigen/src/SparseQR/SparseQR.h index b3d5cd208..0c70609b4 100644 --- a/Eigen/src/SparseQR/SparseQR.h +++ b/Eigen/src/SparseQR/SparseQR.h @@ -21,6 +21,8 @@ namespace internal { template <typename SparseQRType> struct traits<SparseQRMatrixQReturnType<SparseQRType> > { typedef typename SparseQRType::MatrixType ReturnType; + typedef typename ReturnType::Index Index; + typedef typename ReturnType::StorageKind StorageKind; }; template <typename SparseQRType> struct traits<SparseQRMatrixQTransposeReturnType<SparseQRType> > { @@ -72,10 +74,10 @@ class SparseQR typedef Matrix<Scalar, Dynamic, 1> ScalarVector; typedef PermutationMatrix<Dynamic, Dynamic, Index> PermutationType; public: - SparseQR () : m_isInitialized(false), m_analysisIsok(false), m_lastError(""), m_useDefaultThreshold(true) + SparseQR () : m_isInitialized(false), m_analysisIsok(false), m_lastError(""), m_useDefaultThreshold(true),m_isQSorted(false) { } - SparseQR(const MatrixType& mat) : m_isInitialized(false), m_analysisIsok(false), m_lastError(""), m_useDefaultThreshold(true) + SparseQR(const MatrixType& mat) : m_isInitialized(false), m_analysisIsok(false), m_lastError(""), m_useDefaultThreshold(true),m_isQSorted(false) { compute(mat); } @@ -110,11 +112,23 @@ class SparseQR } /** \returns an expression of the matrix Q as products of sparse Householder reflectors. - * You can do the following to get an actual SparseMatrix representation of Q: - * \code - * SparseMatrix<double> Q = SparseQR<SparseMatrix<double> >(A).matrixQ(); - * \endcode - */ + * The common usage of this function is to apply it to a dense matrix or vector + * \code + * VectorXd B1, B2; + * // Initialize B1 + * B2 = matrixQ() * B1; + * \endcode + * + * To get a plain SparseMatrix representation of Q: + * \code + * SparseMatrix<double> Q; + * Q = SparseQR<SparseMatrix<double> >(A).matrixQ(); + * \endcode + * Internally, this call simply performs a sparse product between the matrix Q + * and a sparse identity matrix. However, due to the fact that the sparse + * reflectors are stored unsorted, two transpositions are needed to sort + * them before performing the product. + */ SparseQRMatrixQReturnType<SparseQR> matrixQ() const { return SparseQRMatrixQReturnType<SparseQR>(*this); } @@ -158,6 +172,7 @@ class SparseQR return true; } + /** Sets the threshold that is used to determine linearly dependent columns during the factorization. * * In practice, if during the factorization the norm of the column that has to be eliminated is below @@ -180,6 +195,13 @@ class SparseQR eigen_assert(this->rows() == B.rows() && "SparseQR::solve() : invalid number of rows in the right hand side matrix"); return internal::solve_retval<SparseQR, Rhs>(*this, B.derived()); } + template<typename Rhs> + inline const internal::sparse_solve_retval<SparseQR, Rhs> solve(const SparseMatrixBase<Rhs>& B) const + { + eigen_assert(m_isInitialized && "The factorization should be called first, use compute()"); + eigen_assert(this->rows() == B.rows() && "SparseQR::solve() : invalid number of rows in the right hand side matrix"); + return internal::sparse_solve_retval<SparseQR, Rhs>(*this, B.derived()); + } /** \brief Reports whether previous computation was successful. * @@ -194,6 +216,16 @@ class SparseQR eigen_assert(m_isInitialized && "Decomposition is not initialized."); return m_info; } + + protected: + inline void sort_matrix_Q() + { + // The matrix Q is sorted during the transposition + SparseMatrix<Scalar, RowMajor, Index> mQrm(this->m_Q); + this->m_Q = mQrm; + this->m_isQSorted = true; + } + protected: bool m_isInitialized; @@ -213,8 +245,10 @@ class SparseQR Index m_nonzeropivots; // Number of non zero pivots found IndexVector m_etree; // Column elimination tree IndexVector m_firstRowElt; // First element in each row + bool m_isQSorted; // whether Q is sorted or not template <typename, typename > friend struct SparseQR_QProduct; + template <typename > friend struct SparseQRMatrixQReturnType; }; @@ -462,6 +496,7 @@ void SparseQR<MatrixType,OrderingType>::factorize(const MatrixType& mat) m_Q.makeCompressed(); m_R.finalize(); m_R.makeCompressed(); + m_isQSorted = false; m_nonzeropivots = nonzeroCol; @@ -494,7 +529,18 @@ struct solve_retval<SparseQR<_MatrixType,OrderingType>, Rhs> dec()._solve(rhs(),dst); } }; +template<typename _MatrixType, typename OrderingType, typename Rhs> +struct sparse_solve_retval<SparseQR<_MatrixType, OrderingType>, Rhs> + : sparse_solve_retval_base<SparseQR<_MatrixType, OrderingType>, Rhs> +{ + typedef SparseQR<_MatrixType, OrderingType> Dec; + EIGEN_MAKE_SPARSE_SOLVE_HELPERS(Dec, Rhs) + template<typename Dest> void evalTo(Dest& dst) const + { + this->defaultEvalTo(dst); + } +}; } // end namespace internal template <typename SparseQRType, typename Derived> @@ -513,34 +559,35 @@ struct SparseQR_QProduct : ReturnByValue<SparseQR_QProduct<SparseQRType, Derived template<typename DesType> void evalTo(DesType& res) const { - Index n = m_qr.cols(); + Index n = m_qr.cols(); + res = m_other; if (m_transpose) { eigen_assert(m_qr.m_Q.rows() == m_other.rows() && "Non conforming object sizes"); - // Compute res = Q' * other : - res = m_other; - for (Index k = 0; k < n; k++) - { - Scalar tau = Scalar(0); - tau = m_qr.m_Q.col(k).dot(res); - tau = tau * m_qr.m_hcoeffs(k); - for (typename MatrixType::InnerIterator itq(m_qr.m_Q, k); itq; ++itq) + //Compute res = Q' * other column by column + for(Index j = 0; j < res.cols(); j++){ + for (Index k = 0; k < n; k++) { - res(itq.row()) -= itq.value() * tau; + Scalar tau = Scalar(0); + tau = m_qr.m_Q.col(k).dot(res.col(j)); + tau = tau * m_qr.m_hcoeffs(k); + res.col(j) -= tau * m_qr.m_Q.col(k); } } } else { eigen_assert(m_qr.m_Q.cols() == m_other.rows() && "Non conforming object sizes"); - // Compute res = Q * other : - res = m_other; - for (Index k = n-1; k >=0; k--) + // Compute res = Q' * other column by column + for(Index j = 0; j < res.cols(); j++) { - Scalar tau = Scalar(0); - tau = m_qr.m_Q.col(k).dot(res); - tau = tau * m_qr.m_hcoeffs(k); - res -= tau * m_qr.m_Q.col(k); + for (Index k = n-1; k >=0; k--) + { + Scalar tau = Scalar(0); + tau = m_qr.m_Q.col(k).dot(res.col(j)); + tau = tau * m_qr.m_hcoeffs(k); + res.col(j) -= tau * m_qr.m_Q.col(k); + } } } } @@ -551,8 +598,11 @@ struct SparseQR_QProduct : ReturnByValue<SparseQR_QProduct<SparseQRType, Derived }; template<typename SparseQRType> -struct SparseQRMatrixQReturnType +struct SparseQRMatrixQReturnType : public EigenBase<SparseQRMatrixQReturnType<SparseQRType> > { + typedef typename SparseQRType::Index Index; + typedef typename SparseQRType::Scalar Scalar; + typedef Matrix<Scalar,Dynamic,Dynamic> DenseMatrix; SparseQRMatrixQReturnType(const SparseQRType& qr) : m_qr(qr) {} template<typename Derived> SparseQR_QProduct<SparseQRType, Derived> operator*(const MatrixBase<Derived>& other) @@ -563,11 +613,29 @@ struct SparseQRMatrixQReturnType { return SparseQRMatrixQTransposeReturnType<SparseQRType>(m_qr); } + inline Index rows() const { return m_qr.rows(); } + inline Index cols() const { return m_qr.cols(); } // To use for operations with the transpose of Q SparseQRMatrixQTransposeReturnType<SparseQRType> transpose() const { return SparseQRMatrixQTransposeReturnType<SparseQRType>(m_qr); } + template<typename Dest> void evalTo(MatrixBase<Dest>& dest) const + { + dest.resize(m_qr.rows(), m_qr.cols()); + dest.derived() = m_qr.matrixQ() * Dest::Identity(m_qr.rows(), m_qr.rows()); + } + template<typename Dest> void evalTo(SparseMatrixBase<Dest>& dest) const + { + Dest idMat(m_qr.rows(), m_qr.rows()); + idMat.setIdentity(); + dest.derived().resize(m_qr.rows(), m_qr.cols()); + // Sort the sparse householder reflectors if needed + if(!m_qr.m_isQSorted) + const_cast<SparseQRType *>(&m_qr)->sort_matrix_Q(); + dest.derived() = SparseQR_QProduct<SparseQRType, Dest>(m_qr, idMat, false); + } + const SparseQRType& m_qr; }; |