aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/UmfPackSupport
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2015-06-09 23:30:06 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2015-06-09 23:30:06 +0200
commit0b2cbb2bdc6ea55654da7a11887e823cd619b842 (patch)
tree3b18db8ceef20c6b767d2ee2c83ff74b15119ea0 /Eigen/src/UmfPackSupport
parentfeaf76c00103df5f0ae2d7a365a91f70253922cd (diff)
bug #897: make umfpack support use Ref<>
Diffstat (limited to 'Eigen/src/UmfPackSupport')
-rw-r--r--Eigen/src/UmfPackSupport/UmfPackSupport.h101
1 files changed, 38 insertions, 63 deletions
diff --git a/Eigen/src/UmfPackSupport/UmfPackSupport.h b/Eigen/src/UmfPackSupport/UmfPackSupport.h
index f3a6e7c0e..0a5043ef2 100644
--- a/Eigen/src/UmfPackSupport/UmfPackSupport.h
+++ b/Eigen/src/UmfPackSupport/UmfPackSupport.h
@@ -107,15 +107,6 @@ inline int umfpack_get_determinant(std::complex<double> *Mx, double *Ex, void *N
return umfpack_zi_get_determinant(&mx_real,0,Ex,NumericHandle,User_Info);
}
-namespace internal {
- template<typename T> struct umfpack_helper_is_sparse_plain : false_type {};
- template<typename Scalar, int Options, typename StorageIndex>
- struct umfpack_helper_is_sparse_plain<SparseMatrix<Scalar,Options,StorageIndex> >
- : true_type {};
- template<typename Scalar, int Options, typename StorageIndex>
- struct umfpack_helper_is_sparse_plain<MappedSparseMatrix<Scalar,Options,StorageIndex> >
- : true_type {};
-}
/** \ingroup UmfPackSupport_Module
* \brief A sparse LU factorization and solver based on UmfPack
@@ -147,12 +138,18 @@ class UmfPackLU : public SparseSolverBase<UmfPackLU<_MatrixType> >
typedef Matrix<int, MatrixType::RowsAtCompileTime, 1> IntColVectorType;
typedef SparseMatrix<Scalar> LUMatrixType;
typedef SparseMatrix<Scalar,ColMajor,int> UmfpackMatrixType;
+ typedef Ref<const UmfpackMatrixType, StandardCompressedFormat> UmfpackMatrixRef;
public:
- UmfPackLU() { init(); }
+ UmfPackLU()
+ : m_dummy(0,0), mp_matrix(m_dummy)
+ {
+ init();
+ }
explicit UmfPackLU(const MatrixType& matrix)
+ : mp_matrix(matrix)
{
init();
compute(matrix);
@@ -164,8 +161,8 @@ class UmfPackLU : public SparseSolverBase<UmfPackLU<_MatrixType> >
if(m_numeric) umfpack_free_numeric(&m_numeric,Scalar());
}
- inline Index rows() const { return m_copyMatrix.rows(); }
- inline Index cols() const { return m_copyMatrix.cols(); }
+ inline Index rows() const { return mp_matrix.rows(); }
+ inline Index cols() const { return mp_matrix.cols(); }
/** \brief Reports whether previous computation was successful.
*
@@ -211,7 +208,7 @@ class UmfPackLU : public SparseSolverBase<UmfPackLU<_MatrixType> >
{
if(m_symbolic) umfpack_free_symbolic(&m_symbolic,Scalar());
if(m_numeric) umfpack_free_numeric(&m_numeric,Scalar());
- grapInput(matrix.derived());
+ grab(matrix.derived());
analyzePattern_impl();
factorize_impl();
}
@@ -228,7 +225,7 @@ class UmfPackLU : public SparseSolverBase<UmfPackLU<_MatrixType> >
if(m_symbolic) umfpack_free_symbolic(&m_symbolic,Scalar());
if(m_numeric) umfpack_free_numeric(&m_numeric,Scalar());
- grapInput(matrix.derived());
+ grab(matrix.derived());
analyzePattern_impl();
}
@@ -246,7 +243,7 @@ class UmfPackLU : public SparseSolverBase<UmfPackLU<_MatrixType> >
if(m_numeric)
umfpack_free_numeric(&m_numeric,Scalar());
- grapInput(matrix.derived());
+ grab(matrix.derived());
factorize_impl();
}
@@ -267,53 +264,16 @@ class UmfPackLU : public SparseSolverBase<UmfPackLU<_MatrixType> >
m_isInitialized = false;
m_numeric = 0;
m_symbolic = 0;
- m_outerIndexPtr = 0;
- m_innerIndexPtr = 0;
- m_valuePtr = 0;
m_extractedDataAreDirty = true;
}
- template<typename InputMatrixType>
- void grapInput_impl(const InputMatrixType& mat, internal::true_type)
- {
- m_copyMatrix.resize(mat.rows(), mat.cols());
- if( ((MatrixType::Flags&RowMajorBit)==RowMajorBit) || sizeof(typename MatrixType::StorageIndex)!=sizeof(int) || !mat.isCompressed() )
- {
- // non supported input -> copy
- m_copyMatrix = mat;
- m_outerIndexPtr = m_copyMatrix.outerIndexPtr();
- m_innerIndexPtr = m_copyMatrix.innerIndexPtr();
- m_valuePtr = m_copyMatrix.valuePtr();
- }
- else
- {
- m_outerIndexPtr = mat.outerIndexPtr();
- m_innerIndexPtr = mat.innerIndexPtr();
- m_valuePtr = mat.valuePtr();
- }
- }
-
- template<typename InputMatrixType>
- void grapInput_impl(const InputMatrixType& mat, internal::false_type)
- {
- m_copyMatrix = mat;
- m_outerIndexPtr = m_copyMatrix.outerIndexPtr();
- m_innerIndexPtr = m_copyMatrix.innerIndexPtr();
- m_valuePtr = m_copyMatrix.valuePtr();
- }
-
- template<typename InputMatrixType>
- void grapInput(const InputMatrixType& mat)
- {
- grapInput_impl(mat, internal::umfpack_helper_is_sparse_plain<InputMatrixType>());
- }
-
void analyzePattern_impl()
{
int errorCode = 0;
- errorCode = umfpack_symbolic(internal::convert_index<int>(m_copyMatrix.rows()),
- internal::convert_index<int>(m_copyMatrix.cols()),
- m_outerIndexPtr, m_innerIndexPtr, m_valuePtr, &m_symbolic, 0, 0);
+ errorCode = umfpack_symbolic(internal::convert_index<int>(mp_matrix.rows()),
+ internal::convert_index<int>(mp_matrix.cols()),
+ mp_matrix.outerIndexPtr(), mp_matrix.innerIndexPtr(), mp_matrix.valuePtr(),
+ &m_symbolic, 0, 0);
m_isInitialized = true;
m_info = errorCode ? InvalidInput : Success;
@@ -325,24 +285,39 @@ class UmfPackLU : public SparseSolverBase<UmfPackLU<_MatrixType> >
void factorize_impl()
{
int errorCode;
- errorCode = umfpack_numeric(m_outerIndexPtr, m_innerIndexPtr, m_valuePtr,
+ errorCode = umfpack_numeric(mp_matrix.outerIndexPtr(), mp_matrix.innerIndexPtr(), mp_matrix.valuePtr(),
m_symbolic, &m_numeric, 0, 0);
m_info = errorCode ? NumericalIssue : Success;
m_factorizationIsOk = true;
m_extractedDataAreDirty = true;
}
-
+
+ template<typename MatrixDerived>
+ void grab(const EigenBase<MatrixDerived> &A)
+ {
+ mp_matrix.~UmfpackMatrixRef();
+ ::new (&mp_matrix) UmfpackMatrixRef(A.derived());
+ }
+
+ void grab(const UmfpackMatrixRef &A)
+ {
+ if(&(A.derived()) != &mp_matrix)
+ {
+ mp_matrix.~UmfpackMatrixRef();
+ ::new (&mp_matrix) UmfpackMatrixRef(A);
+ }
+ }
+
// cached data to reduce reallocation, etc.
mutable LUMatrixType m_l;
mutable LUMatrixType m_u;
mutable IntColVectorType m_p;
mutable IntRowVectorType m_q;
- UmfpackMatrixType m_copyMatrix;
- const Scalar* m_valuePtr;
- const int* m_outerIndexPtr;
- const int* m_innerIndexPtr;
+ UmfpackMatrixType m_dummy;
+ UmfpackMatrixRef mp_matrix;
+
void* m_numeric;
void* m_symbolic;
@@ -414,7 +389,7 @@ bool UmfPackLU<MatrixType>::_solve_impl(const MatrixBase<BDerived> &b, MatrixBas
if(x.innerStride()==1)
x_ptr = &x.col(j).coeffRef(0);
errorCode = umfpack_solve(UMFPACK_A,
- m_outerIndexPtr, m_innerIndexPtr, m_valuePtr,
+ mp_matrix.outerIndexPtr(), mp_matrix.innerIndexPtr(), mp_matrix.valuePtr(),
x_ptr, &b.const_cast_derived().col(j).coeffRef(0), m_numeric, 0, 0);
if(x.innerStride()!=1)
x.col(j) = x_tmp;