aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2009-08-11 15:15:06 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2009-08-11 15:15:06 +0200
commitafbd73b5cdc1ce9b8ae54e9dd08332c870cf54d2 (patch)
tree1a7803061dfd11097b94f1e1a54d374d254c3867
parenta4f664251863907604d43be70a41cc4c1dddd42a (diff)
overload operartor* with a ProductBase such that "scalar * (mat * mat)" is optimized
as one could naturally expect
-rw-r--r--Eigen/src/Core/Product.h6
-rw-r--r--Eigen/src/Core/ProductBase.h69
-rw-r--r--Eigen/src/Core/products/GeneralMatrixMatrix.h2
-rw-r--r--Eigen/src/Core/products/SelfadjointMatrixMatrix.h2
-rw-r--r--Eigen/src/Core/products/SelfadjointMatrixVector.h2
-rw-r--r--Eigen/src/Core/products/TriangularMatrixMatrix.h2
-rw-r--r--Eigen/src/Core/products/TriangularMatrixVector.h2
-rw-r--r--test/product_notemporary.cpp2
-rw-r--r--test/product_symm.cpp5
9 files changed, 76 insertions, 16 deletions
diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h
index 18f14f75a..610d5c84a 100644
--- a/Eigen/src/Core/Product.h
+++ b/Eigen/src/Core/Product.h
@@ -153,7 +153,7 @@ class GeneralProduct<Lhs, Rhs, InnerProduct>
GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
- template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
+ template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
ei_assert(dst.rows()==1 && dst.cols()==1);
dst.coeffRef(0,0) += alpha * (m_lhs.cwise()*m_rhs).sum();
@@ -179,7 +179,7 @@ class GeneralProduct<Lhs, Rhs, OuterProduct>
GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
- template<typename Dest> void addTo(Dest& dest, Scalar alpha) const
+ template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
{
ei_outer_product_selector<Dest::Flags&RowMajorBit>::run(*this, dest, alpha);
}
@@ -236,7 +236,7 @@ class GeneralProduct<Lhs, Rhs, GemvProduct>
enum { Side = Lhs::IsVectorAtCompileTime ? OnTheLeft : OnTheRight };
typedef typename ei_meta_if<int(Side)==OnTheRight,_LhsNested,_RhsNested>::ret MatrixType;
- template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
+ template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
ei_assert(m_lhs.rows() == dst.rows() && m_rhs.cols() == dst.cols());
ei_gemv_selector<Side,int(MatrixType::Flags)&RowMajorBit,
diff --git a/Eigen/src/Core/ProductBase.h b/Eigen/src/Core/ProductBase.h
index bae2302b3..0da046b1e 100644
--- a/Eigen/src/Core/ProductBase.h
+++ b/Eigen/src/Core/ProductBase.h
@@ -100,16 +100,16 @@ class ProductBase : public MatrixBase<Derived>
inline int cols() const { return m_rhs.cols(); }
template<typename Dest>
- inline void evalTo(Dest& dst) const { dst.setZero(); addTo(dst,1); }
+ inline void evalTo(Dest& dst) const { dst.setZero(); scaleAndAddTo(dst,1); }
template<typename Dest>
- inline void addTo(Dest& dst) const { addTo(dst,1); }
+ inline void addTo(Dest& dst) const { scaleAndAddTo(dst,1); }
template<typename Dest>
- inline void subTo(Dest& dst) const { addTo(dst,-1); }
+ inline void subTo(Dest& dst) const { scaleAndAddTo(dst,-1); }
template<typename Dest>
- inline void addTo(Dest& dst,Scalar alpha) const { derived().addTo(dst,alpha); }
+ inline void scaleAndAddTo(Dest& dst,Scalar alpha) const { derived().scaleAndAddTo(dst,alpha); }
PlainMatrixType eval() const
{
@@ -141,13 +141,68 @@ class ProductBase : public MatrixBase<Derived>
void coeffRef(int);
};
+template<typename NestedProduct>
+class ScaledProduct;
+
+// Note that these two operator* functions are not defined as member
+// functions of ProductBase, because, otherwise we would have to
+// define all overloads defined in MatrixBase. Furthermore, Using
+// "using Base::operator*" would not work with MSVC.
+template<typename Derived,typename Lhs,typename Rhs>
+const ScaledProduct<Derived> operator*(const ProductBase<Derived,Lhs,Rhs>& prod, typename Derived::Scalar x)
+{ return ScaledProduct<Derived>(prod.derived(), x); }
+
+template<typename Derived,typename Lhs,typename Rhs>
+const ScaledProduct<Derived> operator*(typename Derived::Scalar x,const ProductBase<Derived,Lhs,Rhs>& prod)
+{ return ScaledProduct<Derived>(prod.derived(), x); }
+
+template<typename NestedProduct>
+struct ei_traits<ScaledProduct<NestedProduct> >
+ : ei_traits<ProductBase<ScaledProduct<NestedProduct>,
+ typename NestedProduct::_LhsNested,
+ typename NestedProduct::_RhsNested> >
+{};
+
+template<typename NestedProduct>
+class ScaledProduct
+ : public ProductBase<ScaledProduct<NestedProduct>,
+ typename NestedProduct::_LhsNested,
+ typename NestedProduct::_RhsNested>
+{
+ public:
+ typedef ProductBase<ScaledProduct<NestedProduct>,
+ typename NestedProduct::_LhsNested,
+ typename NestedProduct::_RhsNested> Base;
+ typedef typename Base::Scalar Scalar;
+// EIGEN_PRODUCT_PUBLIC_INTERFACE(ScaledProduct)
+
+ ScaledProduct(const NestedProduct& prod, Scalar& x)
+ : Base(prod.lhs(),prod.rhs()), m_prod(prod), m_alpha(x) {}
+
+ template<typename Dest>
+ inline void evalTo(Dest& dst) const { dst.setZero(); scaleAndAddTo(dst,m_alpha); }
+
+ template<typename Dest>
+ inline void addTo(Dest& dst) const { scaleAndAddTo(dst,m_alpha); }
+
+ template<typename Dest>
+ inline void subTo(Dest& dst) const { scaleAndAddTo(dst,-m_alpha); }
+
+ template<typename Dest>
+ inline void scaleAndAddTo(Dest& dst,Scalar alpha) const { m_prod.derived().scaleAndAddTo(dst,alpha); }
+
+ protected:
+ const NestedProduct& m_prod;
+ Scalar m_alpha;
+};
+
/** \internal
* Overloaded to perform an efficient C = (A*B).lazy() */
template<typename Derived>
template<typename ProductDerived, typename Lhs, typename Rhs>
Derived& MatrixBase<Derived>::lazyAssign(const ProductBase<ProductDerived, Lhs,Rhs>& other)
{
- other.evalTo(derived()); return derived();
+ other.derived().evalTo(derived()); return derived();
}
/** \internal
@@ -157,7 +212,7 @@ template<typename ProductDerived, typename Lhs, typename Rhs>
Derived& MatrixBase<Derived>::operator+=(const Flagged<ProductBase<ProductDerived, Lhs,Rhs>, 0,
EvalBeforeNestingBit | EvalBeforeAssigningBit>& other)
{
- other._expression().addTo(derived()); return derived();
+ other._expression().derived().addTo(derived()); return derived();
}
/** \internal
@@ -167,7 +222,7 @@ template<typename ProductDerived, typename Lhs, typename Rhs>
Derived& MatrixBase<Derived>::operator-=(const Flagged<ProductBase<ProductDerived, Lhs,Rhs>, 0,
EvalBeforeNestingBit | EvalBeforeAssigningBit>& other)
{
- other._expression().subTo(derived()); return derived();
+ other._expression().derived().subTo(derived()); return derived();
}
#endif // EIGEN_PRODUCTBASE_H
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h
index ff0f2c1b4..8b3b13266 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h
@@ -137,7 +137,7 @@ class GeneralProduct<Lhs, Rhs, GemmProduct>
GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
- template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
+ template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
diff --git a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h
index 358da3752..5e025b90b 100644
--- a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h
+++ b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h
@@ -375,7 +375,7 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false>
RhsIsSelfAdjoint = (RhsMode&SelfAdjointBit)==SelfAdjointBit
};
- template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
+ template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
diff --git a/Eigen/src/Core/products/SelfadjointMatrixVector.h b/Eigen/src/Core/products/SelfadjointMatrixVector.h
index f0004cdb9..c2c33d5b8 100644
--- a/Eigen/src/Core/products/SelfadjointMatrixVector.h
+++ b/Eigen/src/Core/products/SelfadjointMatrixVector.h
@@ -175,7 +175,7 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true>
SelfadjointProductMatrix(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
- template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
+ template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
diff --git a/Eigen/src/Core/products/TriangularMatrixMatrix.h b/Eigen/src/Core/products/TriangularMatrixMatrix.h
index c2ee39e79..701ccb644 100644
--- a/Eigen/src/Core/products/TriangularMatrixMatrix.h
+++ b/Eigen/src/Core/products/TriangularMatrixMatrix.h
@@ -333,7 +333,7 @@ struct TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>
TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
- template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
+ template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
diff --git a/Eigen/src/Core/products/TriangularMatrixVector.h b/Eigen/src/Core/products/TriangularMatrixVector.h
index a21afa2f6..620b090b9 100644
--- a/Eigen/src/Core/products/TriangularMatrixVector.h
+++ b/Eigen/src/Core/products/TriangularMatrixVector.h
@@ -130,7 +130,7 @@ struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
- template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
+ template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
diff --git a/test/product_notemporary.cpp b/test/product_notemporary.cpp
index f4311e495..d5d996e49 100644
--- a/test/product_notemporary.cpp
+++ b/test/product_notemporary.cpp
@@ -79,7 +79,7 @@ template<typename MatrixType> void product_notemporary(const MatrixType& m)
VERIFY_EVALUATION_COUNT( m3 = (s1 * m1 * s2 * (m1*s3+m2*s2).adjoint()).lazy(), 1);
VERIFY_EVALUATION_COUNT( m3 = ((s1 * m1).adjoint() * s2 * m2).lazy(), 0);
VERIFY_EVALUATION_COUNT( m3 -= (s1 * (-m1*s3).adjoint() * (s2 * m2 * s3)).lazy(), 0);
- VERIFY_EVALUATION_COUNT( m3 -= (s1 * (m1.transpose() * m2)).lazy(), 1);
+ VERIFY_EVALUATION_COUNT( m3 -= (s1 * (m1.transpose() * m2)).lazy(), 0);
VERIFY_EVALUATION_COUNT(( m3.block(r0,r0,r1,r1) += (-m1.block(r0,c0,r1,c1) * (s2*m2.block(r0,c0,r1,c1)).adjoint()).lazy() ), 0);
VERIFY_EVALUATION_COUNT(( m3.block(r0,r0,r1,r1) -= (s1 * m1.block(r0,c0,r1,c1) * m2.block(c0,r0,c1,r1)).lazy() ), 0);
diff --git a/test/product_symm.cpp b/test/product_symm.cpp
index 88bac878b..1300928a2 100644
--- a/test/product_symm.cpp
+++ b/test/product_symm.cpp
@@ -94,6 +94,11 @@ template<typename Scalar, int Size, int OtherSize> void symm(int size = Size, in
VERIFY_IS_APPROX(rhs12 = (s1*m2.adjoint()).template selfadjointView<LowerTriangular>() * (s2*rhs3).conjugate(),
rhs13 = (s1*m1.adjoint()) * (s2*rhs3).conjugate());
+
+ m2 = m1.template triangularView<UpperTriangular>(); rhs13 = rhs12;
+ VERIFY_IS_APPROX(rhs12 += (s1 * ((m2.adjoint()).template selfadjointView<LowerTriangular>() * (s2*rhs3).conjugate())).lazy(),
+ rhs13 += (s1*m1.adjoint()) * (s2*rhs3).conjugate());
+
// test matrix * selfadjoint
symm_extra<OtherSize>::run(m1,m2,rhs2,rhs22,rhs23,s1,s2);