diff options
author | Gael Guennebaud <g.gael@free.fr> | 2009-08-11 15:15:06 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2009-08-11 15:15:06 +0200 |
commit | afbd73b5cdc1ce9b8ae54e9dd08332c870cf54d2 (patch) | |
tree | 1a7803061dfd11097b94f1e1a54d374d254c3867 /Eigen/src/Core | |
parent | a4f664251863907604d43be70a41cc4c1dddd42a (diff) |
overload operartor* with a ProductBase such that "scalar * (mat * mat)" is optimized
as one could naturally expect
Diffstat (limited to 'Eigen/src/Core')
-rw-r--r-- | Eigen/src/Core/Product.h | 6 | ||||
-rw-r--r-- | Eigen/src/Core/ProductBase.h | 69 | ||||
-rw-r--r-- | Eigen/src/Core/products/GeneralMatrixMatrix.h | 2 | ||||
-rw-r--r-- | Eigen/src/Core/products/SelfadjointMatrixMatrix.h | 2 | ||||
-rw-r--r-- | Eigen/src/Core/products/SelfadjointMatrixVector.h | 2 | ||||
-rw-r--r-- | Eigen/src/Core/products/TriangularMatrixMatrix.h | 2 | ||||
-rw-r--r-- | Eigen/src/Core/products/TriangularMatrixVector.h | 2 |
7 files changed, 70 insertions, 15 deletions
diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h index 18f14f75a..610d5c84a 100644 --- a/Eigen/src/Core/Product.h +++ b/Eigen/src/Core/Product.h @@ -153,7 +153,7 @@ class GeneralProduct<Lhs, Rhs, InnerProduct> GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {} - template<typename Dest> void addTo(Dest& dst, Scalar alpha) const + template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const { ei_assert(dst.rows()==1 && dst.cols()==1); dst.coeffRef(0,0) += alpha * (m_lhs.cwise()*m_rhs).sum(); @@ -179,7 +179,7 @@ class GeneralProduct<Lhs, Rhs, OuterProduct> GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {} - template<typename Dest> void addTo(Dest& dest, Scalar alpha) const + template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const { ei_outer_product_selector<Dest::Flags&RowMajorBit>::run(*this, dest, alpha); } @@ -236,7 +236,7 @@ class GeneralProduct<Lhs, Rhs, GemvProduct> enum { Side = Lhs::IsVectorAtCompileTime ? OnTheLeft : OnTheRight }; typedef typename ei_meta_if<int(Side)==OnTheRight,_LhsNested,_RhsNested>::ret MatrixType; - template<typename Dest> void addTo(Dest& dst, Scalar alpha) const + template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const { ei_assert(m_lhs.rows() == dst.rows() && m_rhs.cols() == dst.cols()); ei_gemv_selector<Side,int(MatrixType::Flags)&RowMajorBit, diff --git a/Eigen/src/Core/ProductBase.h b/Eigen/src/Core/ProductBase.h index bae2302b3..0da046b1e 100644 --- a/Eigen/src/Core/ProductBase.h +++ b/Eigen/src/Core/ProductBase.h @@ -100,16 +100,16 @@ class ProductBase : public MatrixBase<Derived> inline int cols() const { return m_rhs.cols(); } template<typename Dest> - inline void evalTo(Dest& dst) const { dst.setZero(); addTo(dst,1); } + inline void evalTo(Dest& dst) const { dst.setZero(); scaleAndAddTo(dst,1); } template<typename Dest> - inline void addTo(Dest& dst) const { addTo(dst,1); } + inline void addTo(Dest& dst) const { scaleAndAddTo(dst,1); } template<typename Dest> - inline void subTo(Dest& dst) const { addTo(dst,-1); } + inline void subTo(Dest& dst) const { scaleAndAddTo(dst,-1); } template<typename Dest> - inline void addTo(Dest& dst,Scalar alpha) const { derived().addTo(dst,alpha); } + inline void scaleAndAddTo(Dest& dst,Scalar alpha) const { derived().scaleAndAddTo(dst,alpha); } PlainMatrixType eval() const { @@ -141,13 +141,68 @@ class ProductBase : public MatrixBase<Derived> void coeffRef(int); }; +template<typename NestedProduct> +class ScaledProduct; + +// Note that these two operator* functions are not defined as member +// functions of ProductBase, because, otherwise we would have to +// define all overloads defined in MatrixBase. Furthermore, Using +// "using Base::operator*" would not work with MSVC. +template<typename Derived,typename Lhs,typename Rhs> +const ScaledProduct<Derived> operator*(const ProductBase<Derived,Lhs,Rhs>& prod, typename Derived::Scalar x) +{ return ScaledProduct<Derived>(prod.derived(), x); } + +template<typename Derived,typename Lhs,typename Rhs> +const ScaledProduct<Derived> operator*(typename Derived::Scalar x,const ProductBase<Derived,Lhs,Rhs>& prod) +{ return ScaledProduct<Derived>(prod.derived(), x); } + +template<typename NestedProduct> +struct ei_traits<ScaledProduct<NestedProduct> > + : ei_traits<ProductBase<ScaledProduct<NestedProduct>, + typename NestedProduct::_LhsNested, + typename NestedProduct::_RhsNested> > +{}; + +template<typename NestedProduct> +class ScaledProduct + : public ProductBase<ScaledProduct<NestedProduct>, + typename NestedProduct::_LhsNested, + typename NestedProduct::_RhsNested> +{ + public: + typedef ProductBase<ScaledProduct<NestedProduct>, + typename NestedProduct::_LhsNested, + typename NestedProduct::_RhsNested> Base; + typedef typename Base::Scalar Scalar; +// EIGEN_PRODUCT_PUBLIC_INTERFACE(ScaledProduct) + + ScaledProduct(const NestedProduct& prod, Scalar& x) + : Base(prod.lhs(),prod.rhs()), m_prod(prod), m_alpha(x) {} + + template<typename Dest> + inline void evalTo(Dest& dst) const { dst.setZero(); scaleAndAddTo(dst,m_alpha); } + + template<typename Dest> + inline void addTo(Dest& dst) const { scaleAndAddTo(dst,m_alpha); } + + template<typename Dest> + inline void subTo(Dest& dst) const { scaleAndAddTo(dst,-m_alpha); } + + template<typename Dest> + inline void scaleAndAddTo(Dest& dst,Scalar alpha) const { m_prod.derived().scaleAndAddTo(dst,alpha); } + + protected: + const NestedProduct& m_prod; + Scalar m_alpha; +}; + /** \internal * Overloaded to perform an efficient C = (A*B).lazy() */ template<typename Derived> template<typename ProductDerived, typename Lhs, typename Rhs> Derived& MatrixBase<Derived>::lazyAssign(const ProductBase<ProductDerived, Lhs,Rhs>& other) { - other.evalTo(derived()); return derived(); + other.derived().evalTo(derived()); return derived(); } /** \internal @@ -157,7 +212,7 @@ template<typename ProductDerived, typename Lhs, typename Rhs> Derived& MatrixBase<Derived>::operator+=(const Flagged<ProductBase<ProductDerived, Lhs,Rhs>, 0, EvalBeforeNestingBit | EvalBeforeAssigningBit>& other) { - other._expression().addTo(derived()); return derived(); + other._expression().derived().addTo(derived()); return derived(); } /** \internal @@ -167,7 +222,7 @@ template<typename ProductDerived, typename Lhs, typename Rhs> Derived& MatrixBase<Derived>::operator-=(const Flagged<ProductBase<ProductDerived, Lhs,Rhs>, 0, EvalBeforeNestingBit | EvalBeforeAssigningBit>& other) { - other._expression().subTo(derived()); return derived(); + other._expression().derived().subTo(derived()); return derived(); } #endif // EIGEN_PRODUCTBASE_H diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index ff0f2c1b4..8b3b13266 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -137,7 +137,7 @@ class GeneralProduct<Lhs, Rhs, GemmProduct> GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {} - template<typename Dest> void addTo(Dest& dst, Scalar alpha) const + template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const { ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); diff --git a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h index 358da3752..5e025b90b 100644 --- a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +++ b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h @@ -375,7 +375,7 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false> RhsIsSelfAdjoint = (RhsMode&SelfAdjointBit)==SelfAdjointBit }; - template<typename Dest> void addTo(Dest& dst, Scalar alpha) const + template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const { ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); diff --git a/Eigen/src/Core/products/SelfadjointMatrixVector.h b/Eigen/src/Core/products/SelfadjointMatrixVector.h index f0004cdb9..c2c33d5b8 100644 --- a/Eigen/src/Core/products/SelfadjointMatrixVector.h +++ b/Eigen/src/Core/products/SelfadjointMatrixVector.h @@ -175,7 +175,7 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true> SelfadjointProductMatrix(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {} - template<typename Dest> void addTo(Dest& dst, Scalar alpha) const + template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const { ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); diff --git a/Eigen/src/Core/products/TriangularMatrixMatrix.h b/Eigen/src/Core/products/TriangularMatrixMatrix.h index c2ee39e79..701ccb644 100644 --- a/Eigen/src/Core/products/TriangularMatrixMatrix.h +++ b/Eigen/src/Core/products/TriangularMatrixMatrix.h @@ -333,7 +333,7 @@ struct TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false> TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {} - template<typename Dest> void addTo(Dest& dst, Scalar alpha) const + template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const { const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs); const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs); diff --git a/Eigen/src/Core/products/TriangularMatrixVector.h b/Eigen/src/Core/products/TriangularMatrixVector.h index a21afa2f6..620b090b9 100644 --- a/Eigen/src/Core/products/TriangularMatrixVector.h +++ b/Eigen/src/Core/products/TriangularMatrixVector.h @@ -130,7 +130,7 @@ struct TriangularProduct<Mode,true,Lhs,false,Rhs,true> TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {} - template<typename Dest> void addTo(Dest& dst, Scalar alpha) const + template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const { ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); |