aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--Eigen/src/Core/MatrixBase.h3
-rw-r--r--Eigen/src/Core/NoAlias.h4
-rw-r--r--Eigen/src/Core/util/ForwardDeclarations.h2
-rw-r--r--unsupported/Eigen/src/MatrixFunctions/MatrixPower.h75
-rw-r--r--unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h43
5 files changed, 72 insertions, 55 deletions
diff --git a/Eigen/src/Core/MatrixBase.h b/Eigen/src/Core/MatrixBase.h
index c00c1488c..f138b12d2 100644
--- a/Eigen/src/Core/MatrixBase.h
+++ b/Eigen/src/Core/MatrixBase.h
@@ -162,6 +162,9 @@ template<typename Derived> class MatrixBase
#ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename ProductDerived, typename Lhs, typename Rhs>
Derived& lazyAssign(const ProductBase<ProductDerived, Lhs,Rhs>& other);
+
+ template<typename ProductDerived, typename Lhs, typename Rhs>
+ Derived& lazyAssign(const MatrixPowerProductBase<ProductDerived, Lhs,Rhs>& other);
#endif // not EIGEN_PARSED_BY_DOXYGEN
template<typename OtherDerived>
diff --git a/Eigen/src/Core/NoAlias.h b/Eigen/src/Core/NoAlias.h
index fcf2c479c..ac1396f68 100644
--- a/Eigen/src/Core/NoAlias.h
+++ b/Eigen/src/Core/NoAlias.h
@@ -81,8 +81,8 @@ class NoAlias
EIGEN_STRONG_INLINE ExpressionType& operator-=(const CoeffBasedProduct<Lhs,Rhs,NestingFlags>& other)
{ return m_expression.derived() -= CoeffBasedProduct<Lhs,Rhs,NestByRefBit>(other.lhs(), other.rhs()); }
- template<typename Derived>
- EIGEN_STRONG_INLINE ExpressionType& operator=(const MatrixPowerProductBase<Derived>& other)
+ template<typename Derived, typename Lhs, typename Rhs>
+ EIGEN_STRONG_INLINE ExpressionType& operator=(const MatrixPowerProductBase<Derived,Lhs,Rhs>& other)
{ other.derived().evalTo(m_expression); return m_expression; }
#endif
diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h
index 1a3e14b30..58e1d87dc 100644
--- a/Eigen/src/Core/util/ForwardDeclarations.h
+++ b/Eigen/src/Core/util/ForwardDeclarations.h
@@ -272,7 +272,7 @@ template<typename Derived> class MatrixFunctionReturnValue;
template<typename Derived> class MatrixSquareRootReturnValue;
template<typename Derived> class MatrixLogarithmReturnValue;
template<typename Derived> class MatrixPowerReturnValue;
-template<typename Derived> class MatrixPowerProductBase;
+template<typename Derived, typename Lhs, typename Rhs> class MatrixPowerProductBase;
namespace internal {
template <typename Scalar>
diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h b/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h
index 08affb2b5..7aeb69c00 100644
--- a/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h
+++ b/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h
@@ -55,14 +55,14 @@ template<typename MatrixType> class MatrixPower
RealScalar modfAndInit(RealScalar, RealScalar*);
- template<typename PlainObject, typename ResultType>
- void apply(const PlainObject&, ResultType&, bool&);
+ template<typename Derived, typename ResultType>
+ void apply(const Derived&, ResultType&, bool&);
template<typename ResultType>
void computeIntPower(ResultType&, RealScalar);
- template<typename PlainObject, typename ResultType>
- void computeIntPower(const PlainObject&, ResultType&, RealScalar);
+ template<typename Derived, typename ResultType>
+ void computeIntPower(const Derived&, ResultType&, RealScalar);
template<typename ResultType>
void computeFracPower(ResultType&, RealScalar);
@@ -101,8 +101,8 @@ template<typename MatrixType> class MatrixPower
* \param[out] res \f$ A^p b \f$, where A is specified in the
* constructor.
*/
- template<typename PlainObject, typename ResultType>
- void compute(const PlainObject& b, ResultType& res, RealScalar p);
+ template<typename Derived, typename ResultType>
+ void compute(const Derived& b, ResultType& res, RealScalar p);
Index rows() const { return m_A.rows(); }
Index cols() const { return m_A.cols(); }
@@ -133,8 +133,8 @@ void MatrixPower<MatrixType>::compute(MatrixType& res, RealScalar p)
}
template<typename MatrixType>
-template<typename PlainObject, typename ResultType>
-void MatrixPower<MatrixType>::compute(const PlainObject& b, ResultType& res, RealScalar p)
+template<typename Derived, typename ResultType>
+void MatrixPower<MatrixType>::compute(const Derived& b, ResultType& res, RealScalar p)
{
switch (m_A.cols()) {
case 0:
@@ -177,8 +177,8 @@ typename MatrixType::RealScalar MatrixPower<MatrixType>::modfAndInit(RealScalar
}
template<typename MatrixType>
-template<typename PlainObject, typename ResultType>
-void MatrixPower<MatrixType>::apply(const PlainObject& b, ResultType& res, bool& init)
+template<typename Derived, typename ResultType>
+void MatrixPower<MatrixType>::apply(const Derived& b, ResultType& res, bool& init)
{
if (init)
res = m_tmp1 * res;
@@ -206,8 +206,8 @@ void MatrixPower<MatrixType>::computeIntPower(ResultType& res, RealScalar p)
}
template<typename MatrixType>
-template<typename PlainObject, typename ResultType>
-void MatrixPower<MatrixType>::computeIntPower(const PlainObject& b, ResultType& res, RealScalar p)
+template<typename Derived, typename ResultType>
+void MatrixPower<MatrixType>::computeIntPower(const Derived& b, ResultType& res, RealScalar p)
{
if (b.cols() >= m_A.cols()) {
m_tmp2 = MatrixType::Identity(m_A.rows(),m_A.cols());
@@ -262,14 +262,13 @@ void MatrixPower<MatrixType>::computeFracPower(ResultType& res, RealScalar p)
}
}
-template<typename MatrixType, typename PlainObject>
-class MatrixPowerMatrixProduct : public MatrixPowerProductBase<MatrixPowerMatrixProduct<MatrixType,PlainObject> >
+template<typename Lhs, typename Rhs>
+class MatrixPowerMatrixProduct : public MatrixPowerProductBase<MatrixPowerMatrixProduct<Lhs,Rhs>,Lhs,Rhs>
{
public:
- typedef MatrixPowerProductBase<MatrixPowerMatrixProduct<MatrixType,PlainObject> > Base;
- EIGEN_DENSE_PUBLIC_INTERFACE(MatrixPowerMatrixProduct)
+ EIGEN_MATRIX_POWER_PRODUCT_PUBLIC_INTERFACE(MatrixPowerMatrixProduct)
- MatrixPowerMatrixProduct(MatrixPower<MatrixType>& pow, const PlainObject& b, RealScalar p)
+ MatrixPowerMatrixProduct(MatrixPower<Lhs>& pow, const Rhs& b, RealScalar p)
: m_pow(pow), m_b(b), m_p(p) { }
template<typename ResultType>
@@ -280,8 +279,8 @@ class MatrixPowerMatrixProduct : public MatrixPowerProductBase<MatrixPowerMatrix
Index cols() const { return m_b.cols(); }
private:
- MatrixPower<MatrixType>& m_pow;
- const PlainObject& m_b;
+ MatrixPower<Lhs>& m_pow;
+ const Rhs& m_b;
const RealScalar m_p;
MatrixPowerMatrixProduct& operator=(const MatrixPowerMatrixProduct&);
};
@@ -323,7 +322,7 @@ class MatrixPowerReturnValue : public ReturnByValue<MatrixPowerReturnValue<Deriv
*/
template<typename ResultType>
inline void evalTo(ResultType& res) const
- { MatrixPower<typename Derived::PlainObject>(m_A).compute(res, m_p); }
+ { MatrixPower<typename Derived::PlainObject>(m_A.eval()).compute(res, m_p); }
Index rows() const { return m_A.rows(); }
Index cols() const { return m_A.cols(); }
@@ -350,8 +349,8 @@ class MatrixPowerEvaluator
{ m_pow.compute(res, m_p); }
template<typename Derived>
- const MatrixPowerMatrixProduct<MatrixType, typename Derived::PlainObject> operator*(const MatrixBase<Derived>& b) const
- { return MatrixPowerMatrixProduct<MatrixType, typename Derived::PlainObject>(m_pow, b.derived(), m_p); }
+ const MatrixPowerMatrixProduct<MatrixType, Derived> operator*(const MatrixBase<Derived>& b) const
+ { return MatrixPowerMatrixProduct<MatrixType, Derived>(m_pow, b.derived(), m_p); }
Index rows() const { return m_pow.rows(); }
Index cols() const { return m_pow.cols(); }
@@ -363,9 +362,9 @@ class MatrixPowerEvaluator
};
namespace internal {
-template<typename MatrixType, typename PlainObject>
-struct nested<MatrixPowerMatrixProduct<MatrixType,PlainObject> >
-{ typedef PlainObject const& type; };
+template<typename MatrixType, typename Derived>
+struct nested<MatrixPowerMatrixProduct<MatrixType,Derived> >
+{ typedef typename MatrixPowerMatrixProduct<MatrixType,Derived>::PlainObject const& type; };
template<typename Derived>
struct traits<MatrixPowerReturnValue<Derived> >
@@ -375,28 +374,10 @@ template<typename MatrixType>
struct traits<MatrixPowerEvaluator<MatrixType> >
{ typedef MatrixType ReturnType; };
-template<typename MatrixType, typename PlainObject>
-struct traits<MatrixPowerMatrixProduct<MatrixType,PlainObject> >
-{
- typedef MatrixXpr XprKind;
- typedef typename scalar_product_traits<typename MatrixType::Scalar, typename PlainObject::Scalar>::ReturnType Scalar;
- typedef typename promote_storage_type<typename traits<MatrixType>::StorageKind,
- typename traits<PlainObject>::StorageKind>::ret StorageKind;
- typedef typename promote_index_type<typename traits<MatrixType>::Index,
- typename traits<PlainObject>::Index>::type Index;
-
- enum {
- RowsAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(traits<MatrixType>::RowsAtCompileTime,
- traits<PlainObject>::RowsAtCompileTime),
- ColsAtCompileTime = traits<PlainObject>::ColsAtCompileTime,
- MaxRowsAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(traits<MatrixType>::MaxRowsAtCompileTime,
- traits<PlainObject>::MaxRowsAtCompileTime),
- MaxColsAtCompileTime = traits<PlainObject>::MaxColsAtCompileTime,
- Flags = (MaxRowsAtCompileTime==1 ? RowMajorBit : 0)
- | EvalBeforeNestingBit | EvalBeforeAssigningBit | NestByRefBit,
- CoeffReadCost = 0
- };
-};
+template<typename Lhs, typename Rhs>
+struct traits<MatrixPowerMatrixProduct<Lhs,Rhs> >
+: traits<MatrixPowerProductBase<MatrixPowerMatrixProduct<Lhs,Rhs>,Lhs,Rhs> >
+{ };
}
template<typename Derived>
diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h b/unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h
index 0a18fe1c1..28617ff6f 100644
--- a/unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h
+++ b/unsupported/Eigen/src/MatrixFunctions/MatrixPowerBase.h
@@ -29,9 +29,29 @@ struct recompose_complex_schur<0>
{ res = (U * (T.template triangularView<Upper>() * U.adjoint())).real(); }
};
-template<typename Derived>
-struct traits<MatrixPowerProductBase<Derived> > : traits<Derived>
-{ };
+template<typename Derived, typename _Lhs, typename _Rhs>
+struct traits<MatrixPowerProductBase<Derived,_Lhs,_Rhs> >
+{
+ typedef MatrixXpr XprKind;
+ typedef typename remove_all<_Lhs>::type Lhs;
+ typedef typename remove_all<_Rhs>::type Rhs;
+ typedef typename remove_all<Derived>::type PlainObject;
+ typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
+ typedef typename promote_storage_type<typename traits<Lhs>::StorageKind,
+ typename traits<Rhs>::StorageKind>::ret StorageKind;
+ typedef typename promote_index_type<typename traits<Lhs>::Index,
+ typename traits<Rhs>::Index>::type Index;
+
+ enum {
+ RowsAtCompileTime = traits<Lhs>::RowsAtCompileTime,
+ ColsAtCompileTime = traits<Rhs>::ColsAtCompileTime,
+ MaxRowsAtCompileTime = traits<Lhs>::MaxRowsAtCompileTime,
+ MaxColsAtCompileTime = traits<Rhs>::MaxColsAtCompileTime,
+ Flags = (MaxRowsAtCompileTime==1 ? RowMajorBit : 0)
+ | EvalBeforeNestingBit | EvalBeforeAssigningBit | NestByRefBit,
+ CoeffReadCost = 0
+ };
+};
template<typename T>
inline int binary_powering_cost(T p, int* squarings)
@@ -219,13 +239,18 @@ void MatrixPowerTriangularAtomic<MatrixType,UpLo>::computeBig(MatrixType& res, R
compute2x2(res, p);
}
-template<typename Derived>
+#define EIGEN_MATRIX_POWER_PRODUCT_PUBLIC_INTERFACE(Derived) \
+ typedef MatrixPowerProductBase<Derived, Lhs, Rhs > Base; \
+ EIGEN_DENSE_PUBLIC_INTERFACE(Derived)
+
+template<typename Derived, typename Lhs, typename Rhs>
class MatrixPowerProductBase : public MatrixBase<Derived>
{
public:
typedef MatrixBase<Derived> Base;
- typedef typename Base::PlainObject PlainObject;
EIGEN_DENSE_PUBLIC_INTERFACE(MatrixPowerProductBase)
+
+ typedef typename Base::PlainObject PlainObject;
inline Index rows() const { return derived().rows(); }
inline Index cols() const { return derived().cols(); }
@@ -247,6 +272,14 @@ class MatrixPowerProductBase : public MatrixBase<Derived>
mutable PlainObject m_result;
};
+template<typename Derived>
+template<typename ProductDerived, typename Lhs, typename Rhs>
+Derived& MatrixBase<Derived>::lazyAssign(const MatrixPowerProductBase<ProductDerived,Lhs,Rhs>& other)
+{
+ other.derived().evalTo(derived());
+ return derived();
+}
+
} // namespace Eigen
#endif // EIGEN_MATRIX_POWER