diff options
Diffstat (limited to 'unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h')
-rw-r--r-- | unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h | 107 |
1 files changed, 81 insertions, 26 deletions
diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h b/unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h index a809609d5..ca5a604fc 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h @@ -29,30 +29,6 @@ struct recompose_complex_schur<0> { res = (U * (T.template triangularView<Upper>() * U.adjoint())).real(); } }; -template<typename Derived, typename _Lhs, typename _Rhs> -struct traits<MatrixPowerProductBase<Derived,_Lhs,_Rhs> > -{ - typedef MatrixXpr XprKind; - typedef typename remove_all<_Lhs>::type Lhs; - typedef typename remove_all<_Rhs>::type Rhs; - typedef typename remove_all<Derived>::type PlainObject; - typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar; - typedef typename promote_storage_type<typename traits<Lhs>::StorageKind, - typename traits<Rhs>::StorageKind>::ret StorageKind; - typedef typename promote_index_type<typename traits<Lhs>::Index, - typename traits<Rhs>::Index>::type Index; - - enum { - RowsAtCompileTime = traits<Lhs>::RowsAtCompileTime, - ColsAtCompileTime = traits<Rhs>::ColsAtCompileTime, - MaxRowsAtCompileTime = traits<Lhs>::MaxRowsAtCompileTime, - MaxColsAtCompileTime = traits<Rhs>::MaxColsAtCompileTime, - Flags = (MaxRowsAtCompileTime==1 ? RowMajorBit : 0) - | EvalBeforeNestingBit | EvalBeforeAssigningBit | NestByRefBit, - CoeffReadCost = 0 - }; -}; - template<typename T> inline int binary_powering_cost(T p, int* squarings) { @@ -121,7 +97,8 @@ inline int matrix_power_get_pade_degree(long double normIminusT) } } // namespace internal -template<typename MatrixType, int UpLo=Upper> class MatrixPowerTriangularAtomic +template<typename MatrixType, int UpLo=Upper> +class MatrixPowerTriangularAtomic { private: typedef typename MatrixType::Scalar Scalar; @@ -239,10 +216,88 @@ void MatrixPowerTriangularAtomic<MatrixType,UpLo>::computeBig(MatrixType& res, R compute2x2(res, p); } +#define EIGEN_MATRIX_POWER_PUBLIC_INTERFACE(Derived) \ + typedef MatrixPowerBase<Derived<MatrixType>,MatrixType> Base; \ + using typename Base::Scalar; \ + using typename Base::RealScalar; \ + using typename Base::ComplexMatrix; \ + using typename Base::RealArray; + #define EIGEN_MATRIX_POWER_PRODUCT_PUBLIC_INTERFACE(Derived) \ - typedef MatrixPowerProductBase<Derived, Lhs, Rhs > Base; \ + typedef MatrixPowerProductBase<Derived, Lhs, Rhs> Base; \ EIGEN_DENSE_PUBLIC_INTERFACE(Derived) +namespace internal { +template<typename Derived, typename _Lhs, typename _Rhs> +struct traits<MatrixPowerProductBase<Derived,_Lhs,_Rhs> > +{ + typedef MatrixXpr XprKind; + typedef typename remove_all<_Lhs>::type Lhs; + typedef typename remove_all<_Rhs>::type Rhs; + typedef typename remove_all<Derived>::type PlainObject; + typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar; + typedef typename promote_storage_type<typename traits<Lhs>::StorageKind, + typename traits<Rhs>::StorageKind>::ret StorageKind; + typedef typename promote_index_type<typename traits<Lhs>::Index, + typename traits<Rhs>::Index>::type Index; + + enum { + RowsAtCompileTime = traits<Lhs>::RowsAtCompileTime, + ColsAtCompileTime = traits<Rhs>::ColsAtCompileTime, + MaxRowsAtCompileTime = traits<Lhs>::MaxRowsAtCompileTime, + MaxColsAtCompileTime = traits<Rhs>::MaxColsAtCompileTime, + Flags = (MaxRowsAtCompileTime==1 ? RowMajorBit : 0) + | EvalBeforeNestingBit | EvalBeforeAssigningBit | NestByRefBit, + CoeffReadCost = 0 + }; +}; +} // namespace internal + +template<typename Derived, typename MatrixType> +class MatrixPowerBase +{ + protected: + static const int Rows = MatrixType::RowsAtCompileTime; + static const int Cols = MatrixType::ColsAtCompileTime; + static const int Options = MatrixType::Options; + static const int MaxRows = MatrixType::MaxRowsAtCompileTime; + static const int MaxCols = MatrixType::MaxColsAtCompileTime; + + typedef typename MatrixType::Scalar Scalar; + typedef typename MatrixType::RealScalar RealScalar; + typedef typename MatrixType::Index Index; + typedef Matrix<std::complex<RealScalar>,Rows,Cols,Options,MaxRows,MaxCols> ComplexMatrix; + typedef Array<RealScalar,Rows,1,ColMajor,MaxRows> RealArray; + + const MatrixType& m_A; + const bool m_del; // whether to delete the pointer at destruction + + public: + explicit MatrixPowerBase(const MatrixType& A) : + m_A(A), + m_del(false) + { /* empty body */ } + + template<typename OtherDerived> + explicit MatrixPowerBase(const MatrixBase<OtherDerived>& A) : + m_A(*new MatrixType(A)), + m_del(true) + { /* empty body */ } + + ~MatrixPowerBase() + { if (m_del) delete &m_A; } + + void compute(MatrixType& res, RealScalar p) + { static_cast<Derived*>(this)->compute(res,p); } + + template<typename OtherDerived, typename ResultType> + void compute(const OtherDerived& b, ResultType& res, RealScalar p) + { static_cast<Derived*>(this)->compute(b,res,p); } + + Index rows() const { return m_A.rows(); } + Index cols() const { return m_A.cols(); } +}; + template<typename Derived, typename Lhs, typename Rhs> class MatrixPowerProductBase : public MatrixBase<Derived> { |