diff options
Diffstat (limited to 'Eigen/src/Core/SolverBase.h')
-rw-r--r-- | Eigen/src/Core/SolverBase.h | 40 |
1 files changed, 38 insertions, 2 deletions
diff --git a/Eigen/src/Core/SolverBase.h b/Eigen/src/Core/SolverBase.h index 702a5485c..055d3ddc1 100644 --- a/Eigen/src/Core/SolverBase.h +++ b/Eigen/src/Core/SolverBase.h @@ -14,8 +14,35 @@ namespace Eigen { namespace internal { +template<typename Derived> +struct solve_assertion { + template<bool Transpose_, typename Rhs> + static void run(const Derived& solver, const Rhs& b) { solver.template _check_solve_assertion<Transpose_>(b); } +}; + +template<typename Derived> +struct solve_assertion<Transpose<Derived> > +{ + typedef Transpose<Derived> type; + + template<bool Transpose_, typename Rhs> + static void run(const type& transpose, const Rhs& b) + { + internal::solve_assertion<typename internal::remove_all<Derived>::type>::template run<true>(transpose.nestedExpression(), b); + } +}; +template<typename Scalar, typename Derived> +struct solve_assertion<CwiseUnaryOp<Eigen::internal::scalar_conjugate_op<Scalar>, const Transpose<Derived> > > +{ + typedef CwiseUnaryOp<Eigen::internal::scalar_conjugate_op<Scalar>, const Transpose<Derived> > type; + template<bool Transpose_, typename Rhs> + static void run(const type& adjoint, const Rhs& b) + { + internal::solve_assertion<typename internal::remove_all<Transpose<Derived> >::type>::template run<true>(adjoint.nestedExpression(), b); + } +}; } // end namespace internal /** \class SolverBase @@ -35,7 +62,7 @@ namespace internal { * * \warning Currently, any other usage of transpose() and adjoint() are not supported and will produce compilation errors. * - * \sa class PartialPivLU, class FullPivLU + * \sa class PartialPivLU, class FullPivLU, class HouseholderQR, class ColPivHouseholderQR, class FullPivHouseholderQR, class CompleteOrthogonalDecomposition, class LLT, class LDLT, class SVDBase */ template<typename Derived> class SolverBase : public EigenBase<Derived> @@ -46,6 +73,9 @@ class SolverBase : public EigenBase<Derived> typedef typename internal::traits<Derived>::Scalar Scalar; typedef Scalar CoeffReturnType; + template<typename Derived_> + friend struct internal::solve_assertion; + enum { RowsAtCompileTime = internal::traits<Derived>::RowsAtCompileTime, ColsAtCompileTime = internal::traits<Derived>::ColsAtCompileTime, @@ -75,7 +105,7 @@ class SolverBase : public EigenBase<Derived> inline const Solve<Derived, Rhs> solve(const MatrixBase<Rhs>& b) const { - eigen_assert(derived().rows()==b.rows() && "solve(): invalid number of rows of the right hand side matrix b"); + internal::solve_assertion<typename internal::remove_all<Derived>::type>::template run<false>(derived(), b); return Solve<Derived, Rhs>(derived(), b.derived()); } @@ -113,6 +143,12 @@ class SolverBase : public EigenBase<Derived> } protected: + + template<bool Transpose_, typename Rhs> + void _check_solve_assertion(const Rhs& b) const { + eigen_assert(derived().m_isInitialized && "Solver is not initialized."); + eigen_assert((Transpose_?derived().cols():derived().rows())==b.rows() && "SolverBase::solve(): invalid number of rows of the right hand side matrix b"); + } }; namespace internal { |