diff options
-rw-r--r-- | Eigen/src/Core/MatrixBase.h | 3 | ||||
-rw-r--r-- | Eigen/src/Core/NoAlias.h | 4 | ||||
-rw-r--r-- | Eigen/src/Core/util/ForwardDeclarations.h | 2 | ||||
-rw-r--r-- | unsupported/Eigen/src/MatrixFunctions/MatrixPower.h | 75 | ||||
-rw-r--r-- | unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h | 43 |
5 files changed, 72 insertions, 55 deletions
diff --git a/Eigen/src/Core/MatrixBase.h b/Eigen/src/Core/MatrixBase.h index c00c1488c..f138b12d2 100644 --- a/Eigen/src/Core/MatrixBase.h +++ b/Eigen/src/Core/MatrixBase.h @@ -162,6 +162,9 @@ template<typename Derived> class MatrixBase #ifndef EIGEN_PARSED_BY_DOXYGEN template<typename ProductDerived, typename Lhs, typename Rhs> Derived& lazyAssign(const ProductBase<ProductDerived, Lhs,Rhs>& other); + + template<typename ProductDerived, typename Lhs, typename Rhs> + Derived& lazyAssign(const MatrixPowerProductBase<ProductDerived, Lhs,Rhs>& other); #endif // not EIGEN_PARSED_BY_DOXYGEN template<typename OtherDerived> diff --git a/Eigen/src/Core/NoAlias.h b/Eigen/src/Core/NoAlias.h index fcf2c479c..ac1396f68 100644 --- a/Eigen/src/Core/NoAlias.h +++ b/Eigen/src/Core/NoAlias.h @@ -81,8 +81,8 @@ class NoAlias EIGEN_STRONG_INLINE ExpressionType& operator-=(const CoeffBasedProduct<Lhs,Rhs,NestingFlags>& other) { return m_expression.derived() -= CoeffBasedProduct<Lhs,Rhs,NestByRefBit>(other.lhs(), other.rhs()); } - template<typename Derived> - EIGEN_STRONG_INLINE ExpressionType& operator=(const MatrixPowerProductBase<Derived>& other) + template<typename Derived, typename Lhs, typename Rhs> + EIGEN_STRONG_INLINE ExpressionType& operator=(const MatrixPowerProductBase<Derived,Lhs,Rhs>& other) { other.derived().evalTo(m_expression); return m_expression; } #endif diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index 1a3e14b30..58e1d87dc 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -272,7 +272,7 @@ template<typename Derived> class MatrixFunctionReturnValue; template<typename Derived> class MatrixSquareRootReturnValue; template<typename Derived> class MatrixLogarithmReturnValue; template<typename Derived> class MatrixPowerReturnValue; -template<typename Derived> class MatrixPowerProductBase; +template<typename Derived, typename Lhs, typename Rhs> class MatrixPowerProductBase; namespace internal { template <typename Scalar> diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h b/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h index 08affb2b5..7aeb69c00 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h @@ -55,14 +55,14 @@ template<typename MatrixType> class MatrixPower RealScalar modfAndInit(RealScalar, RealScalar*); - template<typename PlainObject, typename ResultType> - void apply(const PlainObject&, ResultType&, bool&); + template<typename Derived, typename ResultType> + void apply(const Derived&, ResultType&, bool&); template<typename ResultType> void computeIntPower(ResultType&, RealScalar); - template<typename PlainObject, typename ResultType> - void computeIntPower(const PlainObject&, ResultType&, RealScalar); + template<typename Derived, typename ResultType> + void computeIntPower(const Derived&, ResultType&, RealScalar); template<typename ResultType> void computeFracPower(ResultType&, RealScalar); @@ -101,8 +101,8 @@ template<typename MatrixType> class MatrixPower * \param[out] res \f$ A^p b \f$, where A is specified in the * constructor. */ - template<typename PlainObject, typename ResultType> - void compute(const PlainObject& b, ResultType& res, RealScalar p); + template<typename Derived, typename ResultType> + void compute(const Derived& b, ResultType& res, RealScalar p); Index rows() const { return m_A.rows(); } Index cols() const { return m_A.cols(); } @@ -133,8 +133,8 @@ void MatrixPower<MatrixType>::compute(MatrixType& res, RealScalar p) } template<typename MatrixType> -template<typename PlainObject, typename ResultType> -void MatrixPower<MatrixType>::compute(const PlainObject& b, ResultType& res, RealScalar p) +template<typename Derived, typename ResultType> +void MatrixPower<MatrixType>::compute(const Derived& b, ResultType& res, RealScalar p) { switch (m_A.cols()) { case 0: @@ -177,8 +177,8 @@ typename MatrixType::RealScalar MatrixPower<MatrixType>::modfAndInit(RealScalar } template<typename MatrixType> -template<typename PlainObject, typename ResultType> -void MatrixPower<MatrixType>::apply(const PlainObject& b, ResultType& res, bool& init) +template<typename Derived, typename ResultType> +void MatrixPower<MatrixType>::apply(const Derived& b, ResultType& res, bool& init) { if (init) res = m_tmp1 * res; @@ -206,8 +206,8 @@ void MatrixPower<MatrixType>::computeIntPower(ResultType& res, RealScalar p) } template<typename MatrixType> -template<typename PlainObject, typename ResultType> -void MatrixPower<MatrixType>::computeIntPower(const PlainObject& b, ResultType& res, RealScalar p) +template<typename Derived, typename ResultType> +void MatrixPower<MatrixType>::computeIntPower(const Derived& b, ResultType& res, RealScalar p) { if (b.cols() >= m_A.cols()) { m_tmp2 = MatrixType::Identity(m_A.rows(),m_A.cols()); @@ -262,14 +262,13 @@ void MatrixPower<MatrixType>::computeFracPower(ResultType& res, RealScalar p) } } -template<typename MatrixType, typename PlainObject> -class MatrixPowerMatrixProduct : public MatrixPowerProductBase<MatrixPowerMatrixProduct<MatrixType,PlainObject> > +template<typename Lhs, typename Rhs> +class MatrixPowerMatrixProduct : public MatrixPowerProductBase<MatrixPowerMatrixProduct<Lhs,Rhs>,Lhs,Rhs> { public: - typedef MatrixPowerProductBase<MatrixPowerMatrixProduct<MatrixType,PlainObject> > Base; - EIGEN_DENSE_PUBLIC_INTERFACE(MatrixPowerMatrixProduct) + EIGEN_MATRIX_POWER_PRODUCT_PUBLIC_INTERFACE(MatrixPowerMatrixProduct) - MatrixPowerMatrixProduct(MatrixPower<MatrixType>& pow, const PlainObject& b, RealScalar p) + MatrixPowerMatrixProduct(MatrixPower<Lhs>& pow, const Rhs& b, RealScalar p) : m_pow(pow), m_b(b), m_p(p) { } template<typename ResultType> @@ -280,8 +279,8 @@ class MatrixPowerMatrixProduct : public MatrixPowerProductBase<MatrixPowerMatrix Index cols() const { return m_b.cols(); } private: - MatrixPower<MatrixType>& m_pow; - const PlainObject& m_b; + MatrixPower<Lhs>& m_pow; + const Rhs& m_b; const RealScalar m_p; MatrixPowerMatrixProduct& operator=(const MatrixPowerMatrixProduct&); }; @@ -323,7 +322,7 @@ class MatrixPowerReturnValue : public ReturnByValue<MatrixPowerReturnValue<Deriv */ template<typename ResultType> inline void evalTo(ResultType& res) const - { MatrixPower<typename Derived::PlainObject>(m_A).compute(res, m_p); } + { MatrixPower<typename Derived::PlainObject>(m_A.eval()).compute(res, m_p); } Index rows() const { return m_A.rows(); } Index cols() const { return m_A.cols(); } @@ -350,8 +349,8 @@ class MatrixPowerEvaluator { m_pow.compute(res, m_p); } template<typename Derived> - const MatrixPowerMatrixProduct<MatrixType, typename Derived::PlainObject> operator*(const MatrixBase<Derived>& b) const - { return MatrixPowerMatrixProduct<MatrixType, typename Derived::PlainObject>(m_pow, b.derived(), m_p); } + const MatrixPowerMatrixProduct<MatrixType, Derived> operator*(const MatrixBase<Derived>& b) const + { return MatrixPowerMatrixProduct<MatrixType, Derived>(m_pow, b.derived(), m_p); } Index rows() const { return m_pow.rows(); } Index cols() const { return m_pow.cols(); } @@ -363,9 +362,9 @@ class MatrixPowerEvaluator }; namespace internal { -template<typename MatrixType, typename PlainObject> -struct nested<MatrixPowerMatrixProduct<MatrixType,PlainObject> > -{ typedef PlainObject const& type; }; +template<typename MatrixType, typename Derived> +struct nested<MatrixPowerMatrixProduct<MatrixType,Derived> > +{ typedef typename MatrixPowerMatrixProduct<MatrixType,Derived>::PlainObject const& type; }; template<typename Derived> struct traits<MatrixPowerReturnValue<Derived> > @@ -375,28 +374,10 @@ template<typename MatrixType> struct traits<MatrixPowerEvaluator<MatrixType> > { typedef MatrixType ReturnType; }; -template<typename MatrixType, typename PlainObject> -struct traits<MatrixPowerMatrixProduct<MatrixType,PlainObject> > -{ - typedef MatrixXpr XprKind; - typedef typename scalar_product_traits<typename MatrixType::Scalar, typename PlainObject::Scalar>::ReturnType Scalar; - typedef typename promote_storage_type<typename traits<MatrixType>::StorageKind, - typename traits<PlainObject>::StorageKind>::ret StorageKind; - typedef typename promote_index_type<typename traits<MatrixType>::Index, - typename traits<PlainObject>::Index>::type Index; - - enum { - RowsAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(traits<MatrixType>::RowsAtCompileTime, - traits<PlainObject>::RowsAtCompileTime), - ColsAtCompileTime = traits<PlainObject>::ColsAtCompileTime, - MaxRowsAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(traits<MatrixType>::MaxRowsAtCompileTime, - traits<PlainObject>::MaxRowsAtCompileTime), - MaxColsAtCompileTime = traits<PlainObject>::MaxColsAtCompileTime, - Flags = (MaxRowsAtCompileTime==1 ? RowMajorBit : 0) - | EvalBeforeNestingBit | EvalBeforeAssigningBit | NestByRefBit, - CoeffReadCost = 0 - }; -}; +template<typename Lhs, typename Rhs> +struct traits<MatrixPowerMatrixProduct<Lhs,Rhs> > +: traits<MatrixPowerProductBase<MatrixPowerMatrixProduct<Lhs,Rhs>,Lhs,Rhs> > +{ }; } template<typename Derived> diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h b/unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h index 0a18fe1c1..28617ff6f 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h @@ -29,9 +29,29 @@ struct recompose_complex_schur<0> { res = (U * (T.template triangularView<Upper>() * U.adjoint())).real(); } }; -template<typename Derived> -struct traits<MatrixPowerProductBase<Derived> > : traits<Derived> -{ }; +template<typename Derived, typename _Lhs, typename _Rhs> +struct traits<MatrixPowerProductBase<Derived,_Lhs,_Rhs> > +{ + typedef MatrixXpr XprKind; + typedef typename remove_all<_Lhs>::type Lhs; + typedef typename remove_all<_Rhs>::type Rhs; + typedef typename remove_all<Derived>::type PlainObject; + typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar; + typedef typename promote_storage_type<typename traits<Lhs>::StorageKind, + typename traits<Rhs>::StorageKind>::ret StorageKind; + typedef typename promote_index_type<typename traits<Lhs>::Index, + typename traits<Rhs>::Index>::type Index; + + enum { + RowsAtCompileTime = traits<Lhs>::RowsAtCompileTime, + ColsAtCompileTime = traits<Rhs>::ColsAtCompileTime, + MaxRowsAtCompileTime = traits<Lhs>::MaxRowsAtCompileTime, + MaxColsAtCompileTime = traits<Rhs>::MaxColsAtCompileTime, + Flags = (MaxRowsAtCompileTime==1 ? RowMajorBit : 0) + | EvalBeforeNestingBit | EvalBeforeAssigningBit | NestByRefBit, + CoeffReadCost = 0 + }; +}; template<typename T> inline int binary_powering_cost(T p, int* squarings) @@ -219,13 +239,18 @@ void MatrixPowerTriangularAtomic<MatrixType,UpLo>::computeBig(MatrixType& res, R compute2x2(res, p); } -template<typename Derived> +#define EIGEN_MATRIX_POWER_PRODUCT_PUBLIC_INTERFACE(Derived) \ + typedef MatrixPowerProductBase<Derived, Lhs, Rhs > Base; \ + EIGEN_DENSE_PUBLIC_INTERFACE(Derived) + +template<typename Derived, typename Lhs, typename Rhs> class MatrixPowerProductBase : public MatrixBase<Derived> { public: typedef MatrixBase<Derived> Base; - typedef typename Base::PlainObject PlainObject; EIGEN_DENSE_PUBLIC_INTERFACE(MatrixPowerProductBase) + + typedef typename Base::PlainObject PlainObject; inline Index rows() const { return derived().rows(); } inline Index cols() const { return derived().cols(); } @@ -247,6 +272,14 @@ class MatrixPowerProductBase : public MatrixBase<Derived> mutable PlainObject m_result; }; +template<typename Derived> +template<typename ProductDerived, typename Lhs, typename Rhs> +Derived& MatrixBase<Derived>::lazyAssign(const MatrixPowerProductBase<ProductDerived,Lhs,Rhs>& other) +{ + other.derived().evalTo(derived()); + return derived(); +} + } // namespace Eigen #endif // EIGEN_MATRIX_POWER |