From 3cf5bb31f6b6e3b6b8f229ed1658cb867fe6e8f5 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Mon, 3 Aug 2009 16:05:15 +0200 Subject: * Bye bye MultiplierBase, extend a bit AnyMatrixBase to allow =, +=, and -= * This probably makes ReturnByValue needless --- Eigen/src/Core/BandMatrix.h | 2 +- Eigen/src/Core/DiagonalMatrix.h | 17 +++++++- Eigen/src/Core/Matrix.h | 1 + Eigen/src/Core/MatrixBase.h | 52 ++++++++++++++++-------- Eigen/src/Core/Product.h | 2 +- Eigen/src/Core/ReturnByValue.h | 32 --------------- Eigen/src/Core/SelfAdjointView.h | 48 ++++++++++++---------- Eigen/src/Core/TriangularMatrix.h | 12 +++--- Eigen/src/Core/products/TriangularMatrixMatrix.h | 19 +++++---- Eigen/src/Core/products/TriangularMatrixVector.h | 19 +++++---- Eigen/src/Core/util/ForwardDeclarations.h | 1 - 11 files changed, 108 insertions(+), 97 deletions(-) (limited to 'Eigen') diff --git a/Eigen/src/Core/BandMatrix.h b/Eigen/src/Core/BandMatrix.h index 2da463afc..c22696992 100644 --- a/Eigen/src/Core/BandMatrix.h +++ b/Eigen/src/Core/BandMatrix.h @@ -57,7 +57,7 @@ struct ei_traits > }; template -class BandMatrix : public MultiplierBase > +class BandMatrix : public AnyMatrixBase > { public: diff --git a/Eigen/src/Core/DiagonalMatrix.h b/Eigen/src/Core/DiagonalMatrix.h index 5fc80c92b..ebbed15d4 100644 --- a/Eigen/src/Core/DiagonalMatrix.h +++ b/Eigen/src/Core/DiagonalMatrix.h @@ -27,7 +27,7 @@ #define EIGEN_DIAGONALMATRIX_H template -class DiagonalBase : public MultiplierBase +class DiagonalBase : public AnyMatrixBase { public: typedef typename ei_traits::DiagonalVectorType DiagonalVectorType; @@ -52,6 +52,12 @@ class DiagonalBase : public MultiplierBase DenseMatrixType toDenseMatrix() const { return derived(); } template void evalToDense(MatrixBase &other) const; + template + void addToDense(MatrixBase &other) const + { other.diagonal() += diagonal(); } + template + void subToDense(MatrixBase &other) const + { other.diagonal() -= diagonal(); } inline const DiagonalVectorType& diagonal() const { return derived().diagonal(); } inline DiagonalVectorType& diagonal() { return derived().diagonal(); } @@ -84,6 +90,7 @@ void DiagonalBase::evalToDense(MatrixBase &other) const */ template struct ei_traits > + : ei_traits > { typedef Matrix<_Scalar,_Size,1> DiagonalVectorType; }; @@ -170,6 +177,14 @@ template struct ei_traits > { typedef _DiagonalVectorType DiagonalVectorType; + typedef typename DiagonalVectorType::Scalar Scalar; + enum { + RowsAtCompileTime = DiagonalVectorType::SizeAtCompileTime, + ColsAtCompileTime = DiagonalVectorType::SizeAtCompileTime, + MaxRowsAtCompileTime = DiagonalVectorType::SizeAtCompileTime, + MaxColsAtCompileTime = DiagonalVectorType::SizeAtCompileTime, + Flags = 0 + }; }; template diff --git a/Eigen/src/Core/Matrix.h b/Eigen/src/Core/Matrix.h index 8937596f2..c31acabca 100644 --- a/Eigen/src/Core/Matrix.h +++ b/Eigen/src/Core/Matrix.h @@ -462,6 +462,7 @@ class Matrix : m_storage(other.derived().rows() * other.derived().cols(), other.derived().rows(), other.derived().cols()) { _check_template_params(); + resize(other.rows(), other.cols()); *this = other; } diff --git a/Eigen/src/Core/MatrixBase.h b/Eigen/src/Core/MatrixBase.h index b881c09c3..30cfbb192 100644 --- a/Eigen/src/Core/MatrixBase.h +++ b/Eigen/src/Core/MatrixBase.h @@ -37,19 +37,31 @@ */ template struct AnyMatrixBase { + typedef typename ei_plain_matrix_type::type PlainMatrixType; + Derived& derived() { return *static_cast(this); } const Derived& derived() const { return *static_cast(this); } -}; -/** Common base class for all classes T such that there are overloaded operator* allowing to - * multiply a MatrixBase by a T on both sides. - * - * In other words, an AnyMatrixBase object is an object that can be multiplied a MatrixBase, the result being again a MatrixBase. - * - * Besides MatrixBase-derived classes, this also includes certain special matrix classes, such as diagonal matrices. - */ -template struct MultiplierBase : public AnyMatrixBase -{ - using AnyMatrixBase::derived; + /** \returns the number of rows. \sa cols(), RowsAtCompileTime */ + inline int rows() const { return derived().rows(); } + /** \returns the number of columns. \sa rows(), ColsAtCompileTime*/ + inline int cols() const { return derived().cols(); } + + template inline void evalTo(Dest& dst) const + { derived().evalTo(dst); } + + template inline void addToDense(Dest& dst) const + { + typename Dest::PlainMatrixType res(rows(),cols()); + evalToDense(res); + dst += res; + } + + template inline void subToDense(Dest& dst) const + { + typename Dest::PlainMatrixType res(rows(),cols()); + evalToDense(res); + dst -= res; + } }; /** \class MatrixBase @@ -79,7 +91,7 @@ template struct MultiplierBase : public AnyMatrixBase */ template class MatrixBase #ifndef EIGEN_PARSED_BY_DOXYGEN - : public MultiplierBase, + : public AnyMatrixBase, public ei_special_scalar_op_base::Scalar, typename NumTraits::Scalar>::Real> #endif // not EIGEN_PARSED_BY_DOXYGEN @@ -298,12 +310,16 @@ template class MatrixBase Derived& operator=(const AnyMatrixBase &other) { other.derived().evalToDense(derived()); return derived(); } + template + Derived& operator+=(const AnyMatrixBase &other) + { other.derived().addToDense(derived()); return derived(); } + + template + Derived& operator-=(const AnyMatrixBase &other) + { other.derived().subToDense(derived()); return derived(); } + template Derived& operator=(const ReturnByValue& func); - template - Derived& operator+=(const ReturnByValue& func); - template - Derived& operator-=(const ReturnByValue& func); #ifndef EIGEN_PARSED_BY_DOXYGEN /** Copies \a other into *this without evaluating other. \returns a reference to *this. */ @@ -410,7 +426,7 @@ template class MatrixBase operator*(const MatrixBase &other) const; template - Derived& operator*=(const MultiplierBase& other); + Derived& operator*=(const AnyMatrixBase& other); template const DiagonalProduct @@ -645,7 +661,7 @@ template class MatrixBase void visit(Visitor& func) const; #ifndef EIGEN_PARSED_BY_DOXYGEN - using MultiplierBase::derived; + using AnyMatrixBase::derived; inline Derived& const_cast_derived() const { return *static_cast(const_cast(this)); } #endif // not EIGEN_PARSED_BY_DOXYGEN diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h index ff45cba3c..1a32eb5de 100644 --- a/Eigen/src/Core/Product.h +++ b/Eigen/src/Core/Product.h @@ -294,7 +294,7 @@ MatrixBase::operator*(const MatrixBase &other) const template template inline Derived & -MatrixBase::operator*=(const MultiplierBase &other) +MatrixBase::operator*=(const AnyMatrixBase &other) { return derived() = derived() * other.derived(); } diff --git a/Eigen/src/Core/ReturnByValue.h b/Eigen/src/Core/ReturnByValue.h index 58b205edc..3f2b478ff 100644 --- a/Eigen/src/Core/ReturnByValue.h +++ b/Eigen/src/Core/ReturnByValue.h @@ -48,14 +48,6 @@ template class ReturnByValue public: template inline void evalTo(Dest& dst) const { static_cast(this)->evalTo(dst); } - template inline void addTo(Dest& dst) const - { static_cast(this)->_addTo(dst); } - template inline void subTo(Dest& dst) const - { static_cast(this)->_subTo(dst); } - template inline void _addTo(Dest& dst) const - { EvalType res; evalTo(res); dst += res; } - template inline void _subTo(Dest& dst) const - { EvalType res; evalTo(res); dst -= res; } }; template @@ -68,14 +60,6 @@ template inline void evalTo(Dest& dst) const { static_cast(this)->evalTo(dst); } - template inline void addTo(Dest& dst) const - { static_cast(this)->_addTo(dst); } - template inline void subTo(Dest& dst) const - { static_cast(this)->_subTo(dst); } - template inline void _addTo(Dest& dst) const - { EvalType res; evalTo(res); dst += res; } - template inline void _subTo(Dest& dst) const - { EvalType res; evalTo(res); dst -= res; } inline int rows() const { return static_cast(this)->rows(); } inline int cols() const { return static_cast(this)->cols(); } }; @@ -88,20 +72,4 @@ Derived& MatrixBase::operator=(const ReturnByValue -template -Derived& MatrixBase::operator+=(const ReturnByValue& other) -{ - other.addTo(derived()); - return derived(); -} - -template -template -Derived& MatrixBase::operator-=(const ReturnByValue& other) -{ - other.subTo(derived()); - return derived(); -} - #endif // EIGEN_RETURNBYVALUE_H diff --git a/Eigen/src/Core/SelfAdjointView.h b/Eigen/src/Core/SelfAdjointView.h index c21f3a377..0a4ba17c0 100644 --- a/Eigen/src/Core/SelfAdjointView.h +++ b/Eigen/src/Core/SelfAdjointView.h @@ -55,7 +55,7 @@ struct ei_traits > : ei_traits -struct ei_selfadjoint_product_returntype; +struct SelfadjointProductMatrix; // FIXME could also be called SelfAdjointWrapper to be consistent with DiagonalWrapper ?? template class SelfAdjointView @@ -100,20 +100,20 @@ template class SelfAdjointView /** Efficient self-adjoint matrix times vector/matrix product */ template - ei_selfadjoint_product_returntype + SelfadjointProductMatrix operator*(const MatrixBase& rhs) const { - return ei_selfadjoint_product_returntype + return SelfadjointProductMatrix (m_matrix, rhs.derived()); } /** Efficient vector/matrix times self-adjoint matrix product */ template friend - ei_selfadjoint_product_returntype + SelfadjointProductMatrix operator*(const MatrixBase& lhs, const SelfAdjointView& rhs) { - return ei_selfadjoint_product_returntype + return SelfadjointProductMatrix (lhs.derived(),rhs.m_matrix); } @@ -201,10 +201,13 @@ struct ei_triangular_assignment_selector -struct ei_selfadjoint_product_returntype - : public ReturnByValue, - Matrix::Scalar, - Rhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> > +struct ei_traits > + : ei_traits::Scalar,Lhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> > +{}; + +template +struct SelfadjointProductMatrix + : public AnyMatrixBase > { typedef typename Lhs::Scalar Scalar; @@ -224,19 +227,19 @@ struct ei_selfadjoint_product_returntype LhsUpLo = LhsMode&(UpperTriangularBit|LowerTriangularBit) }; - ei_selfadjoint_product_returntype(const Lhs& lhs, const Rhs& rhs) + SelfadjointProductMatrix(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) {} inline int rows() const { return m_lhs.rows(); } inline int cols() const { return m_rhs.cols(); } - template inline void _addTo(Dest& dst) const + template inline void addToDense(Dest& dst) const { evalTo(dst,1); } - template inline void _subTo(Dest& dst) const + template inline void subToDense(Dest& dst) const { evalTo(dst,-1); } - template void evalTo(Dest& dst) const + template void evalToDense(Dest& dst) const { dst.setZero(); evalTo(dst,1); @@ -272,12 +275,15 @@ struct ei_selfadjoint_product_returntype ***************************************************************************/ template -struct ei_selfadjoint_product_returntype - : public ReturnByValue, - Matrix::Scalar, - Lhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> > +struct ei_traits > + : ei_traits::Scalar,Lhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> > +{}; + +template +struct SelfadjointProductMatrix + : public AnyMatrixBase > { - ei_selfadjoint_product_returntype(const Lhs& lhs, const Rhs& rhs) + SelfadjointProductMatrix(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) {} @@ -305,12 +311,12 @@ struct ei_selfadjoint_product_returntype RhsIsSelfAdjoint = (RhsMode&SelfAdjointBit)==SelfAdjointBit }; - template inline void _addTo(Dest& dst) const + template inline void addToDense(Dest& dst) const { evalTo(dst,1); } - template inline void _subTo(Dest& dst) const + template inline void subToDense(Dest& dst) const { evalTo(dst,-1); } - template void evalTo(Dest& dst) const + template void evalToDense(Dest& dst) const { dst.setZero(); evalTo(dst,1); diff --git a/Eigen/src/Core/TriangularMatrix.h b/Eigen/src/Core/TriangularMatrix.h index 8b6c9a23b..c262ea7a7 100644 --- a/Eigen/src/Core/TriangularMatrix.h +++ b/Eigen/src/Core/TriangularMatrix.h @@ -43,7 +43,7 @@ * * \sa MatrixBase::part() */ -template class TriangularBase : public MultiplierBase +template class TriangularBase : public AnyMatrixBase { public: @@ -145,7 +145,7 @@ struct ei_traits > : ei_traits template -struct ei_triangular_product_returntype; +struct TriangularProduct; template class TriangularView : public TriangularBase > @@ -253,20 +253,20 @@ template class TriangularView /** Efficient triangular matrix times vector/matrix product */ template - ei_triangular_product_returntype + TriangularProduct operator*(const MatrixBase& rhs) const { - return ei_triangular_product_returntype + return TriangularProduct (m_matrix, rhs.derived()); } /** Efficient vector/matrix times triangular matrix product */ template friend - ei_triangular_product_returntype + TriangularProduct operator*(const MatrixBase& lhs, const TriangularView& rhs) { - return ei_triangular_product_returntype + return TriangularProduct (lhs.derived(),rhs.m_matrix); } diff --git a/Eigen/src/Core/products/TriangularMatrixMatrix.h b/Eigen/src/Core/products/TriangularMatrixMatrix.h index ce18941ee..f69c04365 100644 --- a/Eigen/src/Core/products/TriangularMatrixMatrix.h +++ b/Eigen/src/Core/products/TriangularMatrixMatrix.h @@ -321,12 +321,15 @@ struct ei_product_triangular_matrix_matrix -struct ei_triangular_product_returntype - : public ReturnByValue, - Matrix::Scalar, - Lhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> > +struct ei_traits > + : ei_traits::Scalar,Lhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> > +{}; + +template +struct TriangularProduct + : public AnyMatrixBase > { - ei_triangular_product_returntype(const Lhs& lhs, const Rhs& rhs) + TriangularProduct(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) {} @@ -347,12 +350,12 @@ struct ei_triangular_product_returntype::type _ActualRhsType; - template inline void _addTo(Dest& dst) const + template inline void addToDense(Dest& dst) const { evalTo(dst,1); } - template inline void _subTo(Dest& dst) const + template inline void subToDense(Dest& dst) const { evalTo(dst,-1); } - template void evalTo(Dest& dst) const + template void evalToDense(Dest& dst) const { dst.resize(m_lhs.rows(), m_rhs.cols()); dst.setZero(); diff --git a/Eigen/src/Core/products/TriangularMatrixVector.h b/Eigen/src/Core/products/TriangularMatrixVector.h index 18d76b95c..42239fac0 100644 --- a/Eigen/src/Core/products/TriangularMatrixVector.h +++ b/Eigen/src/Core/products/TriangularMatrixVector.h @@ -118,10 +118,13 @@ struct ei_product_triangular_vector_selector -struct ei_triangular_product_returntype - : public ReturnByValue, - Matrix::Scalar, - Rhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> > +struct ei_traits > + : ei_traits::Scalar,Lhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> > +{}; + +template +struct TriangularProduct + : public AnyMatrixBase > { typedef typename Lhs::Scalar Scalar; @@ -137,19 +140,19 @@ struct ei_triangular_product_returntype typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; typedef typename ei_cleantype::type _ActualRhsType; - ei_triangular_product_returntype(const Lhs& lhs, const Rhs& rhs) + TriangularProduct(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) {} inline int rows() const { return m_lhs.rows(); } inline int cols() const { return m_rhs.cols(); } - template inline void _addTo(Dest& dst) const + template inline void addToDense(Dest& dst) const { evalTo(dst,1); } - template inline void _subTo(Dest& dst) const + template inline void subToDense(Dest& dst) const { evalTo(dst,-1); } - template void evalTo(Dest& dst) const + template void evalToDense(Dest& dst) const { dst.setZero(); evalTo(dst,1); diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index b4fbae28c..310d0fbde 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -29,7 +29,6 @@ template struct ei_traits; template struct NumTraits; template struct AnyMatrixBase; -template struct MultiplierBase; template