aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/SolverBase.h
blob: 5014610420f3e8cea48f0b3fef0a345f1a920f8d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2015 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_SOLVERBASE_H
#define EIGEN_SOLVERBASE_H

namespace Eigen {

namespace internal {

/* solve_assertion<T>::run<Transpose_>(solver, b) performs the argument checks of
 * SolverBase::solve().
 *  - The primary template forwards to the solver's protected
 *    _check_solve_assertion<Transpose_>(); SolverBase declares solve_assertion a friend.
 *  - The two specializations below unwrap the expressions returned by
 *    SolverBase::transpose() and (for complex scalars) SolverBase::adjoint(), and
 *    re-dispatch on the wrapped solver in transposed mode.
 */
template<typename Derived>
struct solve_assertion
{
    template<bool Transpose_, typename Rhs>
    static void run(const Derived& solver, const Rhs& b)
    {
        solver.template _check_solve_assertion<Transpose_>(b);
    }
};

// dec.transpose().solve(b): validate b against the wrapped solver in transposed mode.
// Note: the incoming Transpose_ flag is not forwarded; the nested solver is always
// checked with Transpose_ == true.
template<typename Derived>
struct solve_assertion<Transpose<Derived> >
{
    typedef Transpose<Derived> type;

    template<bool Transpose_, typename Rhs>
    static void run(const type& transpose, const Rhs& b)
    {
        // Derived may carry qualifiers (e.g. Transpose<const Solver>); strip them so the
        // lookup lands on the plain solver type.
        typedef typename internal::remove_all<Derived>::type NestedSolver;
        internal::solve_assertion<NestedSolver>::template run<true>(transpose.nestedExpression(), b);
    }
};

// dec.adjoint().solve(b) when the adjoint is a conjugated transpose: conjugation does not
// change the dimensions, so validate against the inner Transpose<> expression instead.
template<typename Scalar, typename Derived>
struct solve_assertion<CwiseUnaryOp<Eigen::internal::scalar_conjugate_op<Scalar>, const Transpose<Derived> > >
{
    typedef CwiseUnaryOp<Eigen::internal::scalar_conjugate_op<Scalar>, const Transpose<Derived> > type;

    template<bool Transpose_, typename Rhs>
    static void run(const type& adjoint, const Rhs& b)
    {
        // Re-dispatch on the (remove_all-normalized) Transpose<Derived> type, i.e. the
        // Transpose<> specialization above.
        typedef typename internal::remove_all<Transpose<Derived> >::type TransposedSolver;
        internal::solve_assertion<TransposedSolver>::template run<true>(adjoint.nestedExpression(), b);
    }
};
} // end namespace internal

/** \class SolverBase
  * \brief A base class for matrix decompositions and solvers
  *
  * \tparam Derived the actual type of the decomposition/solver.
  *
  * Every matrix decomposition inheriting from this base class provides the following API:
  *
  * \code
  * MatrixType A, b, x;
  * DecompositionType dec(A);
  * x = dec.solve(b);             // solve A   * x = b
  * x = dec.transpose().solve(b); // solve A^T * x = b
  * x = dec.adjoint().solve(b);   // solve A'  * x = b
  * \endcode
  *
  * \warning Currently, any other usage of transpose() and adjoint() is not supported and will produce compilation errors.
  *
  * \sa class PartialPivLU, class FullPivLU, class HouseholderQR, class ColPivHouseholderQR, class FullPivHouseholderQR, class CompleteOrthogonalDecomposition, class LLT, class LDLT, class SVDBase
  */
template<typename Derived>
class SolverBase : public EigenBase<Derived>
{
  public:

    typedef EigenBase<Derived> Base;
    typedef typename internal::traits<Derived>::Scalar Scalar;
    typedef Scalar CoeffReturnType;

    // internal::solve_assertion (and its transpose/adjoint specializations) must be able
    // to call the protected _check_solve_assertion() declared below.
    template<typename Derived_>
    friend struct internal::solve_assertion;

    // Compile-time shape information, forwarded from internal::traits<Derived>.
    enum {
      RowsAtCompileTime     = internal::traits<Derived>::RowsAtCompileTime,
      ColsAtCompileTime     = internal::traits<Derived>::ColsAtCompileTime,
      SizeAtCompileTime     = (internal::size_at_compile_time<internal::traits<Derived>::RowsAtCompileTime,
                                                              internal::traits<Derived>::ColsAtCompileTime>::ret),
      MaxRowsAtCompileTime  = internal::traits<Derived>::MaxRowsAtCompileTime,
      MaxColsAtCompileTime  = internal::traits<Derived>::MaxColsAtCompileTime,
      MaxSizeAtCompileTime  = (internal::size_at_compile_time<internal::traits<Derived>::MaxRowsAtCompileTime,
                                                              internal::traits<Derived>::MaxColsAtCompileTime>::ret),
      // Vector-ness is decided from the *maximum* compile-time dimensions.
      IsVectorAtCompileTime = (internal::traits<Derived>::MaxColsAtCompileTime == 1)
                           || (internal::traits<Derived>::MaxRowsAtCompileTime == 1),
      NumDimensions         = int(MaxSizeAtCompileTime) == 1 ? 0
                            : (bool(IsVectorAtCompileTime) ? 1 : 2)
    };

    /** Default constructor */
    SolverBase() {}

    /** Destructor */
    ~SolverBase() {}

    using Base::derived;

    /** \returns an expression of the solution x of \f$ A x = b \f$ using the current decomposition of A.
      *
      * Before building the expression, the arguments are checked (with eigen_assert) through
      * internal::solve_assertion: the solver must be initialized and the number of rows of
      * \a b must match the decomposition.
      */
    template<typename Rhs>
    inline const Solve<Derived, Rhs>
    solve(const MatrixBase<Rhs>& b) const
    {
      typedef internal::solve_assertion<typename internal::remove_all<Derived>::type> SolveAssertion;
      SolveAssertion::template run<false>(derived(), b);

      const Derived& dec = derived();
      return Solve<Derived, Rhs>(dec, b.derived());
    }

    /** \internal the return type of transpose() */
    typedef typename internal::add_const<Transpose<const Derived> >::type ConstTransposeReturnType;

    /** \returns an expression of the transpose of the factored matrix.
      *
      * Typically used to solve the transposed problem A^T x = b:
      * \code x = dec.transpose().solve(b); \endcode
      *
      * \sa adjoint(), solve()
      */
    inline ConstTransposeReturnType transpose() const
    {
      const Derived& self = derived();
      return ConstTransposeReturnType(self);
    }

    /** \internal the return type of adjoint() */
    typedef typename internal::conditional<NumTraits<Scalar>::IsComplex,
                                           CwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, ConstTransposeReturnType>,
                                           ConstTransposeReturnType
                                          >::type AdjointReturnType;

    /** \returns an expression of the adjoint of the factored matrix.
      *
      * Typically used to solve the adjoint problem A' x = b:
      * \code x = dec.adjoint().solve(b); \endcode
      *
      * For real scalar types this is the same expression as transpose(); for complex
      * scalar types the transpose is additionally conjugated.
      *
      * \sa transpose(), solve()
      */
    inline AdjointReturnType adjoint() const
    {
      // Real scalars: AdjointReturnType is ConstTransposeReturnType itself.
      // Complex scalars: the transpose is wrapped in a scalar_conjugate_op expression.
      return AdjointReturnType(derived().transpose());
    }

  protected:

    /** \internal Checks performed by solve(): the solver must be initialized, and \a b must have
      * as many rows as the factored matrix has columns (when \a Transpose_ is true) or rows
      * (when \a Transpose_ is false).
      */
    template<bool Transpose_, typename Rhs>
    void _check_solve_assertion(const Rhs& b) const
    {
        EIGEN_ONLY_USED_FOR_DEBUG(b);
        // The initialization check must come first: the dimension check queries the solver.
        eigen_assert(derived().m_isInitialized && "Solver is not initialized.");
        eigen_assert((Transpose_ ? derived().cols() : derived().rows()) == b.rows()
                     && "SolverBase::solve(): invalid number of rows of the right hand side matrix b");
    }
};

namespace internal {

// Expressions of kind MatrixXpr whose storage kind is SolverStorage get
// SolverBase<Derived> as their generic API base class (presumably this is what gives
// solver wrappers such as Transpose<const Decomposition> their solve() method — confirm
// against Transpose.h).
template<typename Derived>
struct generic_xpr_base<Derived, MatrixXpr, SolverStorage>
{
  typedef SolverBase<Derived> type;
};

} // end namespace internal

} // end namespace Eigen

#endif // EIGEN_SOLVERBASE_H