aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2015-12-07 12:23:22 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2015-12-07 12:23:22 +0100
commitb37036afce20e902cd5191a2a985f39b1f7e22e3 (patch)
tree4c7409d679d1ecbdf55b3ec518a16264fbbb7587 /Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h
parentf4ca8ad9178b5fa1b83697e1a645e55d65df5639 (diff)
Implement wrapper for matrix-free iterative solvers
Diffstat (limited to 'Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h')
-rw-r--r--Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h162
1 files changed, 140 insertions, 22 deletions
diff --git a/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h b/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h
index e51ff7280..3d62fef6e 100644
--- a/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h
+++ b/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h
@@ -12,6 +12,128 @@
namespace Eigen {
+namespace internal {
+
+template<typename MatrixType>
+struct is_ref_compatible_impl
+{
+private:
+ template <typename T0>
+ struct any_conversion
+ {
+ template <typename T> any_conversion(const volatile T&);
+ template <typename T> any_conversion(T&);
+ };
+ struct yes {int a[1];};
+ struct no {int a[2];};
+
+ template<typename T>
+ static yes test(const Ref<const T>&, int);
+ template<typename T>
+ static no test(any_conversion<T>, ...);
+
+public:
+ static MatrixType ms_from;
+ enum { value = sizeof(test<MatrixType>(ms_from, 0))==sizeof(yes) };
+};
+
+template<typename MatrixType>
+struct is_ref_compatible
+{
+ enum { value = is_ref_compatible_impl<typename remove_all<MatrixType>::type>::value };
+};
+
+template<typename MatrixType, bool MatrixFree = !internal::is_ref_compatible<MatrixType>::value>
+class generic_matrix_wrapper;
+
+// We have an explicit matrix at hand, compatible with Ref<>
+template<typename MatrixType>
+class generic_matrix_wrapper<MatrixType,false>
+{
+public:
+ typedef Ref<const MatrixType> ActualMatrixType;
+ template<int UpLo> struct ConstSelfAdjointViewReturnType {
+ typedef typename ActualMatrixType::template ConstSelfAdjointViewReturnType<UpLo>::Type Type;
+ };
+
+ enum {
+ MatrixFree = false
+ };
+
+ generic_matrix_wrapper()
+ : m_dummy(0,0), m_matrix(m_dummy)
+ {}
+
+ template<typename InputType>
+ generic_matrix_wrapper(const InputType &mat)
+ : m_matrix(mat)
+ {}
+
+ const ActualMatrixType& matrix() const
+ {
+ return m_matrix;
+ }
+
+ template<typename MatrixDerived>
+ void grab(const EigenBase<MatrixDerived> &mat)
+ {
+ m_matrix.~Ref<const MatrixType>();
+ ::new (&m_matrix) Ref<const MatrixType>(mat.derived());
+ }
+
+ void grab(const Ref<const MatrixType> &mat)
+ {
+ if(&(mat.derived()) != &m_matrix)
+ {
+ m_matrix.~Ref<const MatrixType>();
+ ::new (&m_matrix) Ref<const MatrixType>(mat);
+ }
+ }
+
+protected:
+ MatrixType m_dummy; // used to default initialize the Ref<> object
+ ActualMatrixType m_matrix;
+};
+
+// MatrixType is not compatible with Ref<> -> matrix-free wrapper
+template<typename MatrixType>
+class generic_matrix_wrapper<MatrixType,true>
+{
+public:
+ typedef MatrixType ActualMatrixType;
+ template<int UpLo> struct ConstSelfAdjointViewReturnType
+ {
+ typedef ActualMatrixType Type;
+ };
+
+ enum {
+ MatrixFree = true
+ };
+
+ generic_matrix_wrapper()
+ : mp_matrix(0)
+ {}
+
+ generic_matrix_wrapper(const MatrixType &mat)
+ : mp_matrix(&mat)
+ {}
+
+ const ActualMatrixType& matrix() const
+ {
+ return *mp_matrix;
+ }
+
+ void grab(const MatrixType &mat)
+ {
+ mp_matrix = &mat;
+ }
+
+protected:
+ const ActualMatrixType *mp_matrix;
+};
+
+}
+
/** \ingroup IterativeLinearSolvers_Module
* \brief Base class for linear iterative solvers
*
@@ -42,7 +164,6 @@ public:
/** Default constructor. */
IterativeSolverBase()
- : m_dummy(0,0), mp_matrix(m_dummy)
{
init();
}
@@ -59,10 +180,10 @@ public:
*/
template<typename MatrixDerived>
explicit IterativeSolverBase(const EigenBase<MatrixDerived>& A)
- : mp_matrix(A.derived())
+ : m_matrixWrapper(A.derived())
{
init();
- compute(mp_matrix);
+ compute(matrix());
}
~IterativeSolverBase() {}
@@ -76,7 +197,7 @@ public:
Derived& analyzePattern(const EigenBase<MatrixDerived>& A)
{
grab(A.derived());
- m_preconditioner.analyzePattern(mp_matrix);
+ m_preconditioner.analyzePattern(matrix());
m_isInitialized = true;
m_analysisIsOk = true;
m_info = m_preconditioner.info();
@@ -97,7 +218,7 @@ public:
{
eigen_assert(m_analysisIsOk && "You must first call analyzePattern()");
grab(A.derived());
- m_preconditioner.factorize(mp_matrix);
+ m_preconditioner.factorize(matrix());
m_factorizationIsOk = true;
m_info = m_preconditioner.info();
return derived();
@@ -117,7 +238,7 @@ public:
Derived& compute(const EigenBase<MatrixDerived>& A)
{
grab(A.derived());
- m_preconditioner.compute(mp_matrix);
+ m_preconditioner.compute(matrix());
m_isInitialized = true;
m_analysisIsOk = true;
m_factorizationIsOk = true;
@@ -126,10 +247,10 @@ public:
}
/** \internal */
- Index rows() const { return mp_matrix.rows(); }
+ Index rows() const { return matrix().rows(); }
/** \internal */
- Index cols() const { return mp_matrix.cols(); }
+ Index cols() const { return matrix().cols(); }
/** \returns the tolerance threshold used by the stopping criteria.
* \sa setTolerance()
@@ -159,7 +280,7 @@ public:
*/
Index maxIterations() const
{
- return (m_maxIterations<0) ? 2*mp_matrix.cols() : m_maxIterations;
+ return (m_maxIterations<0) ? 2*matrix().cols() : m_maxIterations;
}
/** Sets the max number of iterations.
@@ -239,25 +360,22 @@ protected:
m_maxIterations = -1;
m_tolerance = NumTraits<Scalar>::epsilon();
}
-
- template<typename MatrixDerived>
- void grab(const EigenBase<MatrixDerived> &A)
+
+ typedef internal::generic_matrix_wrapper<MatrixType> MatrixWrapper;
+ typedef typename MatrixWrapper::ActualMatrixType ActualMatrixType;
+
+ const ActualMatrixType& matrix() const
{
- mp_matrix.~Ref<const MatrixType>();
- ::new (&mp_matrix) Ref<const MatrixType>(A.derived());
+ return m_matrixWrapper.matrix();
}
- void grab(const Ref<const MatrixType> &A)
+ template<typename InputType>
+ void grab(const InputType &A)
{
- if(&(A.derived()) != &mp_matrix)
- {
- mp_matrix.~Ref<const MatrixType>();
- ::new (&mp_matrix) Ref<const MatrixType>(A);
- }
+ m_matrixWrapper.grab(A);
}
- MatrixType m_dummy;
- Ref<const MatrixType> mp_matrix;
+ MatrixWrapper m_matrixWrapper;
Preconditioner m_preconditioner;
Index m_maxIterations;