diff options
author | Benoit Jacob <jacob.benoit.1@gmail.com> | 2009-10-30 08:51:33 -0400 |
---|---|---|
committer | Benoit Jacob <jacob.benoit.1@gmail.com> | 2009-10-30 08:51:33 -0400 |
commit | f975b9bd3eb0a862efef290a63a3d1d20a03c099 (patch) | |
tree | 745447790a6135462682b381affaecd73fd4fe39 /Eigen | |
parent | 6b48e932e9b68159d2b0cc9d0d14c4025808327c (diff) |
SVD::solve() : port to new API and improvements
Diffstat (limited to 'Eigen')
-rw-r--r-- | Eigen/src/LU/FullPivLU.h | 2 | ||||
-rw-r--r-- | Eigen/src/SVD/SVD.h | 124 |
2 files changed, 87 insertions, 39 deletions
diff --git a/Eigen/src/LU/FullPivLU.h b/Eigen/src/LU/FullPivLU.h index a28a536b6..067b59549 100644 --- a/Eigen/src/LU/FullPivLU.h +++ b/Eigen/src/LU/FullPivLU.h @@ -200,7 +200,7 @@ template<typename MatrixType> class FullPivLU return ei_fullpivlu_image_impl<MatrixType>(*this, originalMatrix.derived()); } - /** This method returns a solution x to the equation Ax=b, where A is the matrix of which + /** \return a solution x to the equation Ax=b, where A is the matrix of which * *this is the LU decomposition. * * \param b the right-hand-side of the equation to solve. Can be a vector or a matrix, diff --git a/Eigen/src/SVD/SVD.h b/Eigen/src/SVD/SVD.h index da01cf396..807e7058c 100644 --- a/Eigen/src/SVD/SVD.h +++ b/Eigen/src/SVD/SVD.h @@ -25,6 +25,8 @@ #ifndef EIGEN_SVD_H #define EIGEN_SVD_H +template<typename MatrixType, typename Rhs> struct ei_svd_solve_impl; + /** \ingroup SVD_Module * \nonstableyet * @@ -40,24 +42,24 @@ */ template<typename MatrixType> class SVD { - private: + public: typedef typename MatrixType::Scalar Scalar; typedef typename NumTraits<typename MatrixType::Scalar>::Real RealScalar; enum { + RowsAtCompileTime = MatrixType::RowsAtCompileTime, + ColsAtCompileTime = MatrixType::ColsAtCompileTime, PacketSize = ei_packet_traits<Scalar>::size, AlignmentMask = int(PacketSize)-1, - MinSize = EIGEN_ENUM_MIN(MatrixType::RowsAtCompileTime, MatrixType::ColsAtCompileTime) + MinSize = EIGEN_ENUM_MIN(RowsAtCompileTime, ColsAtCompileTime) }; - typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, 1> ColVector; - typedef Matrix<Scalar, MatrixType::ColsAtCompileTime, 1> RowVector; + typedef Matrix<Scalar, RowsAtCompileTime, 1> ColVector; + typedef Matrix<Scalar, ColsAtCompileTime, 1> RowVector; - typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, MatrixType::RowsAtCompileTime> MatrixUType; - typedef Matrix<Scalar, MatrixType::ColsAtCompileTime, MatrixType::ColsAtCompileTime> MatrixVType; - typedef Matrix<Scalar, MatrixType::ColsAtCompileTime, 1> SingularValuesType; - - public: + typedef Matrix<Scalar, RowsAtCompileTime, RowsAtCompileTime> MatrixUType; + typedef Matrix<Scalar, ColsAtCompileTime, ColsAtCompileTime> MatrixVType; + typedef Matrix<Scalar, ColsAtCompileTime, 1> SingularValuesType; /** * \brief Default Constructor. @@ -76,8 +78,24 @@ template<typename MatrixType> class SVD compute(matrix); } - template<typename OtherDerived, typename ResultType> - bool solve(const MatrixBase<OtherDerived> &b, ResultType* result) const; + /** \returns a solution of \f$ A x = b \f$ using the current SVD decomposition of A. + * + * \param b the right-hand-side of the equation to solve. + * + * \note_about_checking_solutions + * + * \note_about_arbitrary_choice_of_solution + * \note_about_using_kernel_to_study_multiple_solutions + * + * \sa MatrixBase::svd(), + */ + template<typename Rhs> + inline const ei_svd_solve_impl<MatrixType, Rhs> + solve(const MatrixBase<Rhs>& b) const + { + ei_assert(m_isInitialized && "SVD is not initialized."); + return ei_svd_solve_impl<MatrixType, Rhs>(*this, b.derived()); + } const MatrixUType& matrixU() const { @@ -108,6 +126,18 @@ template<typename MatrixType> class SVD template<typename ScalingType, typename RotationType> void computeScalingRotation(ScalingType *positive, RotationType *unitary) const; + inline int rows() const + { + ei_assert(m_isInitialized && "SVD is not initialized."); + return m_rows; + } + + inline int cols() const + { + ei_assert(m_isInitialized && "SVD is not initialized."); + return m_cols; + } + protected: // Computes (a^2 + b^2)^(1/2) without destructive underflow or overflow. inline static Scalar pythag(Scalar a, Scalar b) @@ -133,6 +163,7 @@ template<typename MatrixType> class SVD /** \internal */ SingularValuesType m_sigma; bool m_isInitialized; + int m_rows, m_cols; }; /** Computes / recomputes the SVD decomposition A = U S V^* of \a matrix @@ -144,8 +175,8 @@ template<typename MatrixType> class SVD template<typename MatrixType> SVD<MatrixType>& SVD<MatrixType>::compute(const MatrixType& matrix) { - const int m = matrix.rows(); - const int n = matrix.cols(); + const int m = m_rows = matrix.rows(); + const int n = m_cols = matrix.cols(); m_matU.resize(m, m); m_matU.setZero(); @@ -397,40 +428,57 @@ SVD<MatrixType>& SVD<MatrixType>::compute(const MatrixType& matrix) return *this; } -/** \returns the solution of \f$ A x = b \f$ using the current SVD decomposition of A. - * The parts of the solution corresponding to zero singular values are ignored. - * - * \sa MatrixBase::svd(), LU::solve(), LLT::solve() - */ -template<typename MatrixType> -template<typename OtherDerived, typename ResultType> -bool SVD<MatrixType>::solve(const MatrixBase<OtherDerived> &b, ResultType* result) const +template<typename MatrixType,typename Rhs> +struct ei_traits<ei_svd_solve_impl<MatrixType,Rhs> > { - ei_assert(m_isInitialized && "SVD is not initialized."); + typedef Matrix<typename Rhs::Scalar, + MatrixType::ColsAtCompileTime, + Rhs::ColsAtCompileTime, + Rhs::PlainMatrixType::Options, + MatrixType::MaxColsAtCompileTime, + Rhs::MaxColsAtCompileTime> ReturnMatrixType; +}; - const int rows = m_matU.rows(); - ei_assert(b.rows() == rows); +template<typename MatrixType, typename Rhs> +struct ei_svd_solve_impl : public ReturnByValue<ei_svd_solve_impl<MatrixType, Rhs> > +{ + typedef typename ei_cleantype<typename Rhs::Nested>::type RhsNested; + typedef SVD<MatrixType> SVDType; + typedef typename MatrixType::RealScalar RealScalar; + typedef typename MatrixType::Scalar Scalar; + const SVDType& m_svd; + const typename Rhs::Nested m_rhs; + + ei_svd_solve_impl(const SVDType& svd, const Rhs& rhs) + : m_svd(svd), m_rhs(rhs) + {} - result->resize(m_matV.rows(), b.cols()); + inline int rows() const { return m_svd.cols(); } + inline int cols() const { return m_rhs.cols(); } - Scalar maxVal = m_sigma.cwise().abs().maxCoeff(); - for (int j=0; j<b.cols(); ++j) + template<typename Dest> void evalTo(Dest& dst) const { - Matrix<Scalar,MatrixUType::RowsAtCompileTime,1> aux = m_matU.transpose() * b.col(j); + ei_assert(m_rhs.rows() == m_svd.rows()); - for (int i = 0; i <m_matU.cols(); ++i) + dst.resize(rows(), cols()); + + for (int j=0; j<cols(); ++j) { - Scalar si = m_sigma.coeff(i); - if (ei_isMuchSmallerThan(ei_abs(si),maxVal)) - aux.coeffRef(i) = 0; - else - aux.coeffRef(i) /= si; - } + Matrix<Scalar,SVDType::RowsAtCompileTime,1> aux = m_svd.matrixU().adjoint() * m_rhs.col(j); + + for (int i = 0; i <m_svd.rows(); ++i) + { + Scalar si = m_svd.singularValues().coeff(i); + if(si == RealScalar(0)) + aux.coeffRef(i) = Scalar(0); + else + aux.coeffRef(i) /= si; + } - result->col(j) = m_matV * aux; + dst.col(j) = m_svd.matrixV() * aux; + } } - return true; -} +}; /** Computes the polar decomposition of the matrix, as a product unitary x positive. * |