diff options
Diffstat (limited to 'Eigen/src/SVD/SVDBase.h')
-rw-r--r-- | Eigen/src/SVD/SVDBase.h | 63 |
1 files changed, 49 insertions, 14 deletions
diff --git a/Eigen/src/SVD/SVDBase.h b/Eigen/src/SVD/SVDBase.h index 851ad6836..ed1e9f20e 100644 --- a/Eigen/src/SVD/SVDBase.h +++ b/Eigen/src/SVD/SVDBase.h @@ -17,6 +17,18 @@ #define EIGEN_SVDBASE_H namespace Eigen { + +namespace internal { +template<typename Derived> struct traits<SVDBase<Derived> > + : traits<Derived> +{ + typedef MatrixXpr XprKind; + typedef SolverStorage StorageKind; + typedef int StorageIndex; + enum { Flags = 0 }; +}; +} + /** \ingroup SVD_Module * * @@ -44,15 +56,18 @@ namespace Eigen { * terminate in finite (and reasonable) time. * \sa class BDCSVD, class JacobiSVD */ -template<typename Derived> -class SVDBase +template<typename Derived> class SVDBase + : public SolverBase<SVDBase<Derived> > { +public: + + template<typename Derived_> + friend struct internal::solve_assertion; -public: typedef typename internal::traits<Derived>::MatrixType MatrixType; typedef typename MatrixType::Scalar Scalar; typedef typename NumTraits<typename MatrixType::Scalar>::Real RealScalar; - typedef typename MatrixType::StorageIndex StorageIndex; + typedef typename Eigen::internal::traits<SVDBase>::StorageIndex StorageIndex; typedef Eigen::Index Index; ///< \deprecated since Eigen 3.3 enum { RowsAtCompileTime = MatrixType::RowsAtCompileTime, @@ -194,6 +209,7 @@ public: inline Index rows() const { return m_rows; } inline Index cols() const { return m_cols; } + #ifdef EIGEN_PARSED_BY_DOXYGEN /** \returns a (least squares) solution of \f$ A x = b \f$ using the current SVD decomposition of A. * * \param b the right-hand-side of the equation to solve. @@ -205,16 +221,15 @@ public: */ template<typename Rhs> inline const Solve<Derived, Rhs> - solve(const MatrixBase<Rhs>& b) const - { - eigen_assert(m_isInitialized && "SVD is not initialized."); - eigen_assert(computeU() && computeV() && "SVD::solve() requires both unitaries U and V to be computed (thin unitaries suffice)."); - return Solve<Derived, Rhs>(derived(), b.derived()); - } - + solve(const MatrixBase<Rhs>& b) const; + #endif + #ifndef EIGEN_PARSED_BY_DOXYGEN template<typename RhsType, typename DstType> void _solve_impl(const RhsType &rhs, DstType &dst) const; + + template<bool Conjugate, typename RhsType, typename DstType> + void _solve_impl_transposed(const RhsType &rhs, DstType &dst) const; #endif protected: @@ -223,6 +238,13 @@ protected: { EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar); } + + template<bool Transpose_, typename Rhs> + void _check_solve_assertion(const Rhs& b) const { + eigen_assert(m_isInitialized && "SVD is not initialized."); + eigen_assert(computeU() && computeV() && "SVDBase::solve(): Both unitaries U and V are required to be computed (thin unitaries suffice)."); + eigen_assert((Transpose_?cols():rows())==b.rows() && "SVDBase::solve(): invalid number of rows of the right hand side matrix b"); + } // return true if already allocated bool allocate(Index rows, Index cols, unsigned int computationOptions) ; @@ -263,17 +285,30 @@ template<typename Derived> template<typename RhsType, typename DstType> void SVDBase<Derived>::_solve_impl(const RhsType &rhs, DstType &dst) const { - eigen_assert(rhs.rows() == rows()); - // A = U S V^* // So A^{-1} = V S^{-1} U^* - Matrix<Scalar, Dynamic, RhsType::ColsAtCompileTime, 0, MatrixType::MaxRowsAtCompileTime, RhsType::MaxColsAtCompileTime> tmp; + Matrix<typename RhsType::Scalar, Dynamic, RhsType::ColsAtCompileTime, 0, MatrixType::MaxRowsAtCompileTime, RhsType::MaxColsAtCompileTime> tmp; Index l_rank = rank(); tmp.noalias() = m_matrixU.leftCols(l_rank).adjoint() * rhs; tmp = m_singularValues.head(l_rank).asDiagonal().inverse() * tmp; dst = m_matrixV.leftCols(l_rank) * tmp; } + +template<typename Derived> +template<bool Conjugate, typename RhsType, typename DstType> +void SVDBase<Derived>::_solve_impl_transposed(const RhsType &rhs, DstType &dst) const +{ + // A = U S V^* + // So A^{-*} = U S^{-1} V^* + // And A^{-T} = U_conj S^{-1} V^T + Matrix<typename RhsType::Scalar, Dynamic, RhsType::ColsAtCompileTime, 0, MatrixType::MaxRowsAtCompileTime, RhsType::MaxColsAtCompileTime> tmp; + Index l_rank = rank(); + + tmp.noalias() = m_matrixV.leftCols(l_rank).transpose().template conjugateIf<Conjugate>() * rhs; + tmp = m_singularValues.head(l_rank).asDiagonal().inverse() * tmp; + dst = m_matrixU.template conjugateIf<!Conjugate>().leftCols(l_rank) * tmp; +} #endif template<typename MatrixType> |