From 433bce5c3a18e0929e95bae6cb142f1ad920e5ac Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Tue, 2 Dec 2014 17:30:57 +0100 Subject: UmfPack support: fix redundant evaluation/copies when calling compute() and support generic expressions as input --- Eigen/src/UmfPackSupport/UmfPackSupport.h | 89 +++++++++++++++++++++---------- 1 file changed, 62 insertions(+), 27 deletions(-) (limited to 'Eigen/src/UmfPackSupport') diff --git a/Eigen/src/UmfPackSupport/UmfPackSupport.h b/Eigen/src/UmfPackSupport/UmfPackSupport.h index 29ee516d9..b8b216d5e 100644 --- a/Eigen/src/UmfPackSupport/UmfPackSupport.h +++ b/Eigen/src/UmfPackSupport/UmfPackSupport.h @@ -107,6 +107,16 @@ inline int umfpack_get_determinant(std::complex *Mx, double *Ex, void *N return umfpack_zi_get_determinant(&mx_real,0,Ex,NumericHandle,User_Info); } +namespace internal { + template struct umfpack_helper_is_sparse_plain : false_type {}; + template + struct umfpack_helper_is_sparse_plain > + : true_type {}; + template + struct umfpack_helper_is_sparse_plain > + : true_type {}; +} + /** \ingroup UmfPackSupport_Module * \brief A sparse LU factorization and solver based on UmfPack * @@ -199,8 +209,11 @@ class UmfPackLU : public SparseSolverBase > template void compute(const InputMatrixType& matrix) { - analyzePattern(matrix); - factorize(matrix); + if(m_symbolic) umfpack_free_symbolic(&m_symbolic,Scalar()); + if(m_numeric) umfpack_free_numeric(&m_numeric,Scalar()); + grapInput(matrix.derived()); + analyzePattern_impl(); + factorize_impl(); } /** Performs a symbolic decomposition on the sparcity of \a matrix. @@ -212,22 +225,12 @@ class UmfPackLU : public SparseSolverBase > template void analyzePattern(const InputMatrixType& matrix) { - if(m_symbolic) - umfpack_free_symbolic(&m_symbolic,Scalar()); - if(m_numeric) - umfpack_free_numeric(&m_numeric,Scalar()); + if(m_symbolic) umfpack_free_symbolic(&m_symbolic,Scalar()); + if(m_numeric) umfpack_free_numeric(&m_numeric,Scalar()); - grapInput(matrix); + grapInput(matrix.derived()); - int errorCode = 0; - errorCode = umfpack_symbolic(matrix.rows(), matrix.cols(), m_outerIndexPtr, m_innerIndexPtr, m_valuePtr, - &m_symbolic, 0, 0); - - m_isInitialized = true; - m_info = errorCode ? InvalidInput : Success; - m_analysisIsOk = true; - m_factorizationIsOk = false; - m_extractedDataAreDirty = true; + analyzePattern_impl(); } /** Performs a numeric decomposition of \a matrix @@ -243,15 +246,9 @@ class UmfPackLU : public SparseSolverBase > if(m_numeric) umfpack_free_numeric(&m_numeric,Scalar()); - grapInput(matrix); - - int errorCode; - errorCode = umfpack_numeric(m_outerIndexPtr, m_innerIndexPtr, m_valuePtr, - m_symbolic, &m_numeric, 0, 0); - - m_info = errorCode ? NumericalIssue : Success; - m_factorizationIsOk = true; - m_extractedDataAreDirty = true; + grapInput(matrix.derived()); + + factorize_impl(); } #ifndef EIGEN_PARSED_BY_DOXYGEN @@ -266,7 +263,6 @@ class UmfPackLU : public SparseSolverBase > protected: - void init() { m_info = InvalidInput; @@ -280,7 +276,7 @@ class UmfPackLU : public SparseSolverBase > } template - void grapInput(const InputMatrixType& mat) + void grapInput_impl(const InputMatrixType& mat, internal::true_type) { m_copyMatrix.resize(mat.rows(), mat.cols()); if( ((MatrixType::Flags&RowMajorBit)==RowMajorBit) || sizeof(typename MatrixType::Index)!=sizeof(int) || !mat.isCompressed() ) @@ -298,6 +294,45 @@ class UmfPackLU : public SparseSolverBase > m_valuePtr = mat.valuePtr(); } } + + template + void grapInput_impl(const InputMatrixType& mat, internal::false_type) + { + m_copyMatrix = mat; + m_outerIndexPtr = m_copyMatrix.outerIndexPtr(); + m_innerIndexPtr = m_copyMatrix.innerIndexPtr(); + m_valuePtr = m_copyMatrix.valuePtr(); + } + + template + void grapInput(const InputMatrixType& mat) + { + grapInput_impl(mat, internal::umfpack_helper_is_sparse_plain()); + } + + void analyzePattern_impl() + { + int errorCode = 0; + errorCode = umfpack_symbolic(m_copyMatrix.rows(), m_copyMatrix.cols(), m_outerIndexPtr, m_innerIndexPtr, m_valuePtr, + &m_symbolic, 0, 0); + + m_isInitialized = true; + m_info = errorCode ? InvalidInput : Success; + m_analysisIsOk = true; + m_factorizationIsOk = false; + m_extractedDataAreDirty = true; + } + + void factorize_impl() + { + int errorCode; + errorCode = umfpack_numeric(m_outerIndexPtr, m_innerIndexPtr, m_valuePtr, + m_symbolic, &m_numeric, 0, 0); + + m_info = errorCode ? NumericalIssue : Success; + m_factorizationIsOk = true; + m_extractedDataAreDirty = true; + } // cached data to reduce reallocation, etc. mutable LUMatrixType m_l; -- cgit v1.2.3