diff options
author | Gael Guennebaud <g.gael@free.fr> | 2014-12-02 17:30:57 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2014-12-02 17:30:57 +0100 |
commit | 433bce5c3a18e0929e95bae6cb142f1ad920e5ac (patch) | |
tree | 72085f5f8ec7c83f1c3f694b2bc653b8475b6b3c /Eigen/src/UmfPackSupport | |
parent | 775f7e5fbb1e40bc227f459415c6a0dcfa5cd61a (diff) |
UmfPack support: fix redundant evaluation/copies when calling compute() and support generic expressions as input
Diffstat (limited to 'Eigen/src/UmfPackSupport')
-rw-r--r-- | Eigen/src/UmfPackSupport/UmfPackSupport.h | 89 |
1 files changed, 62 insertions, 27 deletions
diff --git a/Eigen/src/UmfPackSupport/UmfPackSupport.h b/Eigen/src/UmfPackSupport/UmfPackSupport.h index 29ee516d9..b8b216d5e 100644 --- a/Eigen/src/UmfPackSupport/UmfPackSupport.h +++ b/Eigen/src/UmfPackSupport/UmfPackSupport.h @@ -107,6 +107,16 @@ inline int umfpack_get_determinant(std::complex<double> *Mx, double *Ex, void *N return umfpack_zi_get_determinant(&mx_real,0,Ex,NumericHandle,User_Info); } +namespace internal { + template<typename T> struct umfpack_helper_is_sparse_plain : false_type {}; + template<typename Scalar, int Options, typename StorageIndex> + struct umfpack_helper_is_sparse_plain<SparseMatrix<Scalar,Options,StorageIndex> > + : true_type {}; + template<typename Scalar, int Options, typename StorageIndex> + struct umfpack_helper_is_sparse_plain<MappedSparseMatrix<Scalar,Options,StorageIndex> > + : true_type {}; +} + /** \ingroup UmfPackSupport_Module * \brief A sparse LU factorization and solver based on UmfPack * @@ -199,8 +209,11 @@ class UmfPackLU : public SparseSolverBase<UmfPackLU<_MatrixType> > template<typename InputMatrixType> void compute(const InputMatrixType& matrix) { - analyzePattern(matrix); - factorize(matrix); + if(m_symbolic) umfpack_free_symbolic(&m_symbolic,Scalar()); + if(m_numeric) umfpack_free_numeric(&m_numeric,Scalar()); + grapInput(matrix.derived()); + analyzePattern_impl(); + factorize_impl(); } /** Performs a symbolic decomposition on the sparcity of \a matrix. @@ -212,22 +225,12 @@ class UmfPackLU : public SparseSolverBase<UmfPackLU<_MatrixType> > template<typename InputMatrixType> void analyzePattern(const InputMatrixType& matrix) { - if(m_symbolic) - umfpack_free_symbolic(&m_symbolic,Scalar()); - if(m_numeric) - umfpack_free_numeric(&m_numeric,Scalar()); + if(m_symbolic) umfpack_free_symbolic(&m_symbolic,Scalar()); + if(m_numeric) umfpack_free_numeric(&m_numeric,Scalar()); - grapInput(matrix); + grapInput(matrix.derived()); - int errorCode = 0; - errorCode = umfpack_symbolic(matrix.rows(), matrix.cols(), m_outerIndexPtr, m_innerIndexPtr, m_valuePtr, - &m_symbolic, 0, 0); - - m_isInitialized = true; - m_info = errorCode ? InvalidInput : Success; - m_analysisIsOk = true; - m_factorizationIsOk = false; - m_extractedDataAreDirty = true; + analyzePattern_impl(); } /** Performs a numeric decomposition of \a matrix @@ -243,15 +246,9 @@ class UmfPackLU : public SparseSolverBase<UmfPackLU<_MatrixType> > if(m_numeric) umfpack_free_numeric(&m_numeric,Scalar()); - grapInput(matrix); - - int errorCode; - errorCode = umfpack_numeric(m_outerIndexPtr, m_innerIndexPtr, m_valuePtr, - m_symbolic, &m_numeric, 0, 0); - - m_info = errorCode ? NumericalIssue : Success; - m_factorizationIsOk = true; - m_extractedDataAreDirty = true; + grapInput(matrix.derived()); + + factorize_impl(); } #ifndef EIGEN_PARSED_BY_DOXYGEN @@ -266,7 +263,6 @@ class UmfPackLU : public SparseSolverBase<UmfPackLU<_MatrixType> > protected: - void init() { m_info = InvalidInput; @@ -280,7 +276,7 @@ class UmfPackLU : public SparseSolverBase<UmfPackLU<_MatrixType> > } template<typename InputMatrixType> - void grapInput(const InputMatrixType& mat) + void grapInput_impl(const InputMatrixType& mat, internal::true_type) { m_copyMatrix.resize(mat.rows(), mat.cols()); if( ((MatrixType::Flags&RowMajorBit)==RowMajorBit) || sizeof(typename MatrixType::Index)!=sizeof(int) || !mat.isCompressed() ) @@ -298,6 +294,45 @@ class UmfPackLU : public SparseSolverBase<UmfPackLU<_MatrixType> > m_valuePtr = mat.valuePtr(); } } + + template<typename InputMatrixType> + void grapInput_impl(const InputMatrixType& mat, internal::false_type) + { + m_copyMatrix = mat; + m_outerIndexPtr = m_copyMatrix.outerIndexPtr(); + m_innerIndexPtr = m_copyMatrix.innerIndexPtr(); + m_valuePtr = m_copyMatrix.valuePtr(); + } + + template<typename InputMatrixType> + void grapInput(const InputMatrixType& mat) + { + grapInput_impl(mat, internal::umfpack_helper_is_sparse_plain<InputMatrixType>()); + } + + void analyzePattern_impl() + { + int errorCode = 0; + errorCode = umfpack_symbolic(m_copyMatrix.rows(), m_copyMatrix.cols(), m_outerIndexPtr, m_innerIndexPtr, m_valuePtr, + &m_symbolic, 0, 0); + + m_isInitialized = true; + m_info = errorCode ? InvalidInput : Success; + m_analysisIsOk = true; + m_factorizationIsOk = false; + m_extractedDataAreDirty = true; + } + + void factorize_impl() + { + int errorCode; + errorCode = umfpack_numeric(m_outerIndexPtr, m_innerIndexPtr, m_valuePtr, + m_symbolic, &m_numeric, 0, 0); + + m_info = errorCode ? NumericalIssue : Success; + m_factorizationIsOk = true; + m_extractedDataAreDirty = true; + } // cached data to reduce reallocation, etc. mutable LUMatrixType m_l; |