diff options
Diffstat (limited to 'Eigen/src/UmfPackSupport/UmfPackSupport.h')
-rw-r--r-- | Eigen/src/UmfPackSupport/UmfPackSupport.h | 195 |
1 files changed, 97 insertions, 98 deletions
diff --git a/Eigen/src/UmfPackSupport/UmfPackSupport.h b/Eigen/src/UmfPackSupport/UmfPackSupport.h index 3a48cecf7..a2bb75b09 100644 --- a/Eigen/src/UmfPackSupport/UmfPackSupport.h +++ b/Eigen/src/UmfPackSupport/UmfPackSupport.h @@ -107,6 +107,16 @@ inline int umfpack_get_determinant(std::complex<double> *Mx, double *Ex, void *N return umfpack_zi_get_determinant(&mx_real,0,Ex,NumericHandle,User_Info); } +namespace internal { + template<typename T> struct umfpack_helper_is_sparse_plain : false_type {}; + template<typename Scalar, int Options, typename StorageIndex> + struct umfpack_helper_is_sparse_plain<SparseMatrix<Scalar,Options,StorageIndex> > + : true_type {}; + template<typename Scalar, int Options, typename StorageIndex> + struct umfpack_helper_is_sparse_plain<MappedSparseMatrix<Scalar,Options,StorageIndex> > + : true_type {}; +} + /** \ingroup UmfPackSupport_Module * \brief A sparse LU factorization and solver based on UmfPack * @@ -121,9 +131,13 @@ inline int umfpack_get_determinant(std::complex<double> *Mx, double *Ex, void *N * \sa \ref TutorialSparseDirectSolvers */ template<typename _MatrixType> -class UmfPackLU : internal::noncopyable +class UmfPackLU : public SparseSolverBase<UmfPackLU<_MatrixType> > { + protected: + typedef SparseSolverBase<UmfPackLU<_MatrixType> > Base; + using Base::m_isInitialized; public: + using Base::_solve_impl; typedef _MatrixType MatrixType; typedef typename MatrixType::Scalar Scalar; typedef typename MatrixType::RealScalar RealScalar; @@ -138,7 +152,7 @@ class UmfPackLU : internal::noncopyable UmfPackLU() { init(); } - UmfPackLU(const MatrixType& matrix) + explicit UmfPackLU(const MatrixType& matrix) { init(); compute(matrix); @@ -192,36 +206,14 @@ class UmfPackLU : internal::noncopyable * Note that the matrix should be column-major, and in compressed format for best performance. * \sa SparseMatrix::makeCompressed(). */ - void compute(const MatrixType& matrix) - { - analyzePattern(matrix); - factorize(matrix); - } - - /** \returns the solution x of \f$ A x = b \f$ using the current decomposition of A. - * - * \sa compute() - */ - template<typename Rhs> - inline const internal::solve_retval<UmfPackLU, Rhs> solve(const MatrixBase<Rhs>& b) const - { - eigen_assert(m_isInitialized && "UmfPackLU is not initialized."); - eigen_assert(rows()==b.rows() - && "UmfPackLU::solve(): invalid number of rows of the right hand side matrix b"); - return internal::solve_retval<UmfPackLU, Rhs>(*this, b.derived()); - } - - /** \returns the solution x of \f$ A x = b \f$ using the current decomposition of A. - * - * \sa compute() - */ - template<typename Rhs> - inline const internal::sparse_solve_retval<UmfPackLU, Rhs> solve(const SparseMatrixBase<Rhs>& b) const + template<typename InputMatrixType> + void compute(const InputMatrixType& matrix) { - eigen_assert(m_isInitialized && "UmfPackLU is not initialized."); - eigen_assert(rows()==b.rows() - && "UmfPackLU::solve(): invalid number of rows of the right hand side matrix b"); - return internal::sparse_solve_retval<UmfPackLU, Rhs>(*this, b.derived()); + if(m_symbolic) umfpack_free_symbolic(&m_symbolic,Scalar()); + if(m_numeric) umfpack_free_numeric(&m_numeric,Scalar()); + grapInput(matrix.derived()); + analyzePattern_impl(); + factorize_impl(); } /** Performs a symbolic decomposition on the sparcity of \a matrix. @@ -230,23 +222,15 @@ class UmfPackLU : internal::noncopyable * * \sa factorize(), compute() */ - void analyzePattern(const MatrixType& matrix) + template<typename InputMatrixType> + void analyzePattern(const InputMatrixType& matrix) { - if(m_symbolic) - umfpack_free_symbolic(&m_symbolic,Scalar()); - if(m_numeric) - umfpack_free_numeric(&m_numeric,Scalar()); + if(m_symbolic) umfpack_free_symbolic(&m_symbolic,Scalar()); + if(m_numeric) umfpack_free_numeric(&m_numeric,Scalar()); - grapInput(matrix); + grapInput(matrix.derived()); - int errorCode = 0; - errorCode = umfpack_symbolic(matrix.rows(), matrix.cols(), m_outerIndexPtr, m_innerIndexPtr, m_valuePtr, - &m_symbolic, 0, 0); - - m_isInitialized = true; - m_info = errorCode ? InvalidInput : Success; - m_analysisIsOk = true; - m_factorizationIsOk = false; + analyzePattern_impl(); } /** Performs a numeric decomposition of \a matrix @@ -255,26 +239,22 @@ class UmfPackLU : internal::noncopyable * * \sa analyzePattern(), compute() */ - void factorize(const MatrixType& matrix) + template<typename InputMatrixType> + void factorize(const InputMatrixType& matrix) { eigen_assert(m_analysisIsOk && "UmfPackLU: you must first call analyzePattern()"); if(m_numeric) umfpack_free_numeric(&m_numeric,Scalar()); - grapInput(matrix); - - int errorCode; - errorCode = umfpack_numeric(m_outerIndexPtr, m_innerIndexPtr, m_valuePtr, - m_symbolic, &m_numeric, 0, 0); - - m_info = errorCode ? NumericalIssue : Success; - m_factorizationIsOk = true; + grapInput(matrix.derived()); + + factorize_impl(); } #ifndef EIGEN_PARSED_BY_DOXYGEN /** \internal */ template<typename BDerived,typename XDerived> - bool _solve(const MatrixBase<BDerived> &b, MatrixBase<XDerived> &x) const; + bool _solve_impl(const MatrixBase<BDerived> &b, MatrixBase<XDerived> &x) const; #endif Scalar determinant() const; @@ -283,19 +263,20 @@ class UmfPackLU : internal::noncopyable protected: - void init() { - m_info = InvalidInput; - m_isInitialized = false; - m_numeric = 0; - m_symbolic = 0; - m_outerIndexPtr = 0; - m_innerIndexPtr = 0; - m_valuePtr = 0; + m_info = InvalidInput; + m_isInitialized = false; + m_numeric = 0; + m_symbolic = 0; + m_outerIndexPtr = 0; + m_innerIndexPtr = 0; + m_valuePtr = 0; + m_extractedDataAreDirty = true; } - void grapInput(const MatrixType& mat) + template<typename InputMatrixType> + void grapInput_impl(const InputMatrixType& mat, internal::true_type) { m_copyMatrix.resize(mat.rows(), mat.cols()); if( ((MatrixType::Flags&RowMajorBit)==RowMajorBit) || sizeof(typename MatrixType::Index)!=sizeof(int) || !mat.isCompressed() ) @@ -313,6 +294,45 @@ class UmfPackLU : internal::noncopyable m_valuePtr = mat.valuePtr(); } } + + template<typename InputMatrixType> + void grapInput_impl(const InputMatrixType& mat, internal::false_type) + { + m_copyMatrix = mat; + m_outerIndexPtr = m_copyMatrix.outerIndexPtr(); + m_innerIndexPtr = m_copyMatrix.innerIndexPtr(); + m_valuePtr = m_copyMatrix.valuePtr(); + } + + template<typename InputMatrixType> + void grapInput(const InputMatrixType& mat) + { + grapInput_impl(mat, internal::umfpack_helper_is_sparse_plain<InputMatrixType>()); + } + + void analyzePattern_impl() + { + int errorCode = 0; + errorCode = umfpack_symbolic(m_copyMatrix.rows(), m_copyMatrix.cols(), m_outerIndexPtr, m_innerIndexPtr, m_valuePtr, + &m_symbolic, 0, 0); + + m_isInitialized = true; + m_info = errorCode ? InvalidInput : Success; + m_analysisIsOk = true; + m_factorizationIsOk = false; + m_extractedDataAreDirty = true; + } + + void factorize_impl() + { + int errorCode; + errorCode = umfpack_numeric(m_outerIndexPtr, m_innerIndexPtr, m_valuePtr, + m_symbolic, &m_numeric, 0, 0); + + m_info = errorCode ? NumericalIssue : Success; + m_factorizationIsOk = true; + m_extractedDataAreDirty = true; + } // cached data to reduce reallocation, etc. mutable LUMatrixType m_l; @@ -328,7 +348,6 @@ class UmfPackLU : internal::noncopyable void* m_symbolic; mutable ComputationInfo m_info; - bool m_isInitialized; int m_factorizationIsOk; int m_analysisIsOk; mutable bool m_extractedDataAreDirty; @@ -376,7 +395,7 @@ typename UmfPackLU<MatrixType>::Scalar UmfPackLU<MatrixType>::determinant() cons template<typename MatrixType> template<typename BDerived,typename XDerived> -bool UmfPackLU<MatrixType>::_solve(const MatrixBase<BDerived> &b, MatrixBase<XDerived> &x) const +bool UmfPackLU<MatrixType>::_solve_impl(const MatrixBase<BDerived> &b, MatrixBase<XDerived> &x) const { const int rhsCols = b.cols(); eigen_assert((BDerived::Flags&RowMajorBit)==0 && "UmfPackLU backend does not support non col-major rhs yet"); @@ -384,11 +403,22 @@ bool UmfPackLU<MatrixType>::_solve(const MatrixBase<BDerived> &b, MatrixBase<XDe eigen_assert(b.derived().data() != x.derived().data() && " Umfpack does not support inplace solve"); int errorCode; + Scalar* x_ptr = 0; + Matrix<Scalar,Dynamic,1> x_tmp; + if(x.innerStride()!=1) + { + x_tmp.resize(x.rows()); + x_ptr = x_tmp.data(); + } for (int j=0; j<rhsCols; ++j) { + if(x.innerStride()==1) + x_ptr = &x.col(j).coeffRef(0); errorCode = umfpack_solve(UMFPACK_A, m_outerIndexPtr, m_innerIndexPtr, m_valuePtr, - &x.col(j).coeffRef(0), &b.const_cast_derived().col(j).coeffRef(0), m_numeric, 0, 0); + x_ptr, &b.const_cast_derived().col(j).coeffRef(0), m_numeric, 0, 0); + if(x.innerStride()!=1) + x.col(j) = x_tmp; if (errorCode!=0) return false; } @@ -396,37 +426,6 @@ bool UmfPackLU<MatrixType>::_solve(const MatrixBase<BDerived> &b, MatrixBase<XDe return true; } - -namespace internal { - -template<typename _MatrixType, typename Rhs> -struct solve_retval<UmfPackLU<_MatrixType>, Rhs> - : solve_retval_base<UmfPackLU<_MatrixType>, Rhs> -{ - typedef UmfPackLU<_MatrixType> Dec; - EIGEN_MAKE_SOLVE_HELPERS(Dec,Rhs) - - template<typename Dest> void evalTo(Dest& dst) const - { - dec()._solve(rhs(),dst); - } -}; - -template<typename _MatrixType, typename Rhs> -struct sparse_solve_retval<UmfPackLU<_MatrixType>, Rhs> - : sparse_solve_retval_base<UmfPackLU<_MatrixType>, Rhs> -{ - typedef UmfPackLU<_MatrixType> Dec; - EIGEN_MAKE_SPARSE_SOLVE_HELPERS(Dec,Rhs) - - template<typename Dest> void evalTo(Dest& dst) const - { - this->defaultEvalTo(dst); - } -}; - -} // end namespace internal - } // end namespace Eigen #endif // EIGEN_UMFPACKSUPPORT_H |