aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core
diff options
context:
space:
mode:
Diffstat (limited to 'Eigen/src/Core')
-rw-r--r--Eigen/src/Core/Product.h6
-rw-r--r--Eigen/src/Core/ProductBase.h69
-rw-r--r--Eigen/src/Core/products/GeneralMatrixMatrix.h2
-rw-r--r--Eigen/src/Core/products/SelfadjointMatrixMatrix.h2
-rw-r--r--Eigen/src/Core/products/SelfadjointMatrixVector.h2
-rw-r--r--Eigen/src/Core/products/TriangularMatrixMatrix.h2
-rw-r--r--Eigen/src/Core/products/TriangularMatrixVector.h2
7 files changed, 70 insertions, 15 deletions
diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h
index 18f14f75a..610d5c84a 100644
--- a/Eigen/src/Core/Product.h
+++ b/Eigen/src/Core/Product.h
@@ -153,7 +153,7 @@ class GeneralProduct<Lhs, Rhs, InnerProduct>
GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
- template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
+ template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
ei_assert(dst.rows()==1 && dst.cols()==1);
dst.coeffRef(0,0) += alpha * (m_lhs.cwise()*m_rhs).sum();
@@ -179,7 +179,7 @@ class GeneralProduct<Lhs, Rhs, OuterProduct>
GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
- template<typename Dest> void addTo(Dest& dest, Scalar alpha) const
+ template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
{
ei_outer_product_selector<Dest::Flags&RowMajorBit>::run(*this, dest, alpha);
}
@@ -236,7 +236,7 @@ class GeneralProduct<Lhs, Rhs, GemvProduct>
enum { Side = Lhs::IsVectorAtCompileTime ? OnTheLeft : OnTheRight };
typedef typename ei_meta_if<int(Side)==OnTheRight,_LhsNested,_RhsNested>::ret MatrixType;
- template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
+ template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
ei_assert(m_lhs.rows() == dst.rows() && m_rhs.cols() == dst.cols());
ei_gemv_selector<Side,int(MatrixType::Flags)&RowMajorBit,
diff --git a/Eigen/src/Core/ProductBase.h b/Eigen/src/Core/ProductBase.h
index bae2302b3..0da046b1e 100644
--- a/Eigen/src/Core/ProductBase.h
+++ b/Eigen/src/Core/ProductBase.h
@@ -100,16 +100,16 @@ class ProductBase : public MatrixBase<Derived>
inline int cols() const { return m_rhs.cols(); }
template<typename Dest>
- inline void evalTo(Dest& dst) const { dst.setZero(); addTo(dst,1); }
+ inline void evalTo(Dest& dst) const { dst.setZero(); scaleAndAddTo(dst,1); }
template<typename Dest>
- inline void addTo(Dest& dst) const { addTo(dst,1); }
+ inline void addTo(Dest& dst) const { scaleAndAddTo(dst,1); }
template<typename Dest>
- inline void subTo(Dest& dst) const { addTo(dst,-1); }
+ inline void subTo(Dest& dst) const { scaleAndAddTo(dst,-1); }
template<typename Dest>
- inline void addTo(Dest& dst,Scalar alpha) const { derived().addTo(dst,alpha); }
+ inline void scaleAndAddTo(Dest& dst,Scalar alpha) const { derived().scaleAndAddTo(dst,alpha); }
PlainMatrixType eval() const
{
@@ -141,13 +141,68 @@ class ProductBase : public MatrixBase<Derived>
void coeffRef(int);
};
+template<typename NestedProduct>
+class ScaledProduct;
+
+// Note that these two operator* functions are not defined as member
+// functions of ProductBase, because, otherwise we would have to
+// define all overloads defined in MatrixBase. Furthermore, Using
+// "using Base::operator*" would not work with MSVC.
+template<typename Derived,typename Lhs,typename Rhs>
+const ScaledProduct<Derived> operator*(const ProductBase<Derived,Lhs,Rhs>& prod, typename Derived::Scalar x)
+{ return ScaledProduct<Derived>(prod.derived(), x); }
+
+template<typename Derived,typename Lhs,typename Rhs>
+const ScaledProduct<Derived> operator*(typename Derived::Scalar x,const ProductBase<Derived,Lhs,Rhs>& prod)
+{ return ScaledProduct<Derived>(prod.derived(), x); }
+
+template<typename NestedProduct>
+struct ei_traits<ScaledProduct<NestedProduct> >
+ : ei_traits<ProductBase<ScaledProduct<NestedProduct>,
+ typename NestedProduct::_LhsNested,
+ typename NestedProduct::_RhsNested> >
+{};
+
+template<typename NestedProduct>
+class ScaledProduct
+ : public ProductBase<ScaledProduct<NestedProduct>,
+ typename NestedProduct::_LhsNested,
+ typename NestedProduct::_RhsNested>
+{
+ public:
+ typedef ProductBase<ScaledProduct<NestedProduct>,
+ typename NestedProduct::_LhsNested,
+ typename NestedProduct::_RhsNested> Base;
+ typedef typename Base::Scalar Scalar;
+// EIGEN_PRODUCT_PUBLIC_INTERFACE(ScaledProduct)
+
+ ScaledProduct(const NestedProduct& prod, Scalar& x)
+ : Base(prod.lhs(),prod.rhs()), m_prod(prod), m_alpha(x) {}
+
+ template<typename Dest>
+ inline void evalTo(Dest& dst) const { dst.setZero(); scaleAndAddTo(dst,m_alpha); }
+
+ template<typename Dest>
+ inline void addTo(Dest& dst) const { scaleAndAddTo(dst,m_alpha); }
+
+ template<typename Dest>
+ inline void subTo(Dest& dst) const { scaleAndAddTo(dst,-m_alpha); }
+
+ template<typename Dest>
+ inline void scaleAndAddTo(Dest& dst,Scalar alpha) const { m_prod.derived().scaleAndAddTo(dst,alpha); }
+
+ protected:
+ const NestedProduct& m_prod;
+ Scalar m_alpha;
+};
+
/** \internal
* Overloaded to perform an efficient C = (A*B).lazy() */
template<typename Derived>
template<typename ProductDerived, typename Lhs, typename Rhs>
Derived& MatrixBase<Derived>::lazyAssign(const ProductBase<ProductDerived, Lhs,Rhs>& other)
{
- other.evalTo(derived()); return derived();
+ other.derived().evalTo(derived()); return derived();
}
/** \internal
@@ -157,7 +212,7 @@ template<typename ProductDerived, typename Lhs, typename Rhs>
Derived& MatrixBase<Derived>::operator+=(const Flagged<ProductBase<ProductDerived, Lhs,Rhs>, 0,
EvalBeforeNestingBit | EvalBeforeAssigningBit>& other)
{
- other._expression().addTo(derived()); return derived();
+ other._expression().derived().addTo(derived()); return derived();
}
/** \internal
@@ -167,7 +222,7 @@ template<typename ProductDerived, typename Lhs, typename Rhs>
Derived& MatrixBase<Derived>::operator-=(const Flagged<ProductBase<ProductDerived, Lhs,Rhs>, 0,
EvalBeforeNestingBit | EvalBeforeAssigningBit>& other)
{
- other._expression().subTo(derived()); return derived();
+ other._expression().derived().subTo(derived()); return derived();
}
#endif // EIGEN_PRODUCTBASE_H
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h
index ff0f2c1b4..8b3b13266 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h
@@ -137,7 +137,7 @@ class GeneralProduct<Lhs, Rhs, GemmProduct>
GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
- template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
+ template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
diff --git a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h
index 358da3752..5e025b90b 100644
--- a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h
+++ b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h
@@ -375,7 +375,7 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false>
RhsIsSelfAdjoint = (RhsMode&SelfAdjointBit)==SelfAdjointBit
};
- template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
+ template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
diff --git a/Eigen/src/Core/products/SelfadjointMatrixVector.h b/Eigen/src/Core/products/SelfadjointMatrixVector.h
index f0004cdb9..c2c33d5b8 100644
--- a/Eigen/src/Core/products/SelfadjointMatrixVector.h
+++ b/Eigen/src/Core/products/SelfadjointMatrixVector.h
@@ -175,7 +175,7 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true>
SelfadjointProductMatrix(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
- template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
+ template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
diff --git a/Eigen/src/Core/products/TriangularMatrixMatrix.h b/Eigen/src/Core/products/TriangularMatrixMatrix.h
index c2ee39e79..701ccb644 100644
--- a/Eigen/src/Core/products/TriangularMatrixMatrix.h
+++ b/Eigen/src/Core/products/TriangularMatrixMatrix.h
@@ -333,7 +333,7 @@ struct TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>
TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
- template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
+ template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
diff --git a/Eigen/src/Core/products/TriangularMatrixVector.h b/Eigen/src/Core/products/TriangularMatrixVector.h
index a21afa2f6..620b090b9 100644
--- a/Eigen/src/Core/products/TriangularMatrixVector.h
+++ b/Eigen/src/Core/products/TriangularMatrixVector.h
@@ -130,7 +130,7 @@ struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
- template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
+ template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());