aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2013-02-25 13:31:42 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2013-02-25 13:31:42 +0100
commit5a0c5c039322eabbe3ef73a97f33ac85c4505da2 (patch)
tree128f0911a2d56de1dbddabb2340e94141337005c /Eigen
parent96ad13abba9ca90fdf02ad34c05ab3d10667d7a5 (diff)
Fix bug #483: optimize outer-products to skip setZero and a scalar multiple when not needed.
Diffstat (limited to 'Eigen')
-rw-r--r--Eigen/src/Core/GeneralProduct.h40
1 files changed, 33 insertions, 7 deletions
diff --git a/Eigen/src/Core/GeneralProduct.h b/Eigen/src/Core/GeneralProduct.h
index 9abc7b286..06fb8e6c0 100644
--- a/Eigen/src/Core/GeneralProduct.h
+++ b/Eigen/src/Core/GeneralProduct.h
@@ -243,36 +243,62 @@ class GeneralProduct<Lhs, Rhs, OuterProduct>
EIGEN_STATIC_ASSERT((internal::is_same<typename Lhs::RealScalar, typename Rhs::RealScalar>::value),
YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
}
+
+ struct set { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() = src; } };
+ struct add { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() += src; } };
+ struct sub { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() -= src; } };
+ struct adds {
+ Scalar m_scale;
+ adds(const Scalar& s) : m_scale(s) {}
+ template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const {
+ dst.const_cast_derived() += m_scale * src;
+ }
+ };
+
+ template<typename Dest>
+ inline void evalTo(Dest& dest) const {
+ internal::outer_product_selector<(int(Dest::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dest, set());
+ }
+
+ template<typename Dest>
+ inline void addTo(Dest& dest) const {
+ internal::outer_product_selector<(int(Dest::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dest, add());
+ }
+
+ template<typename Dest>
+ inline void subTo(Dest& dest) const {
+ internal::outer_product_selector<(int(Dest::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dest, sub());
+ }
template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
{
- internal::outer_product_selector<(int(Dest::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dest, alpha);
+ internal::outer_product_selector<(int(Dest::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dest, adds(alpha));
}
};
namespace internal {
template<> struct outer_product_selector<ColMajor> {
- template<typename ProductType, typename Dest>
- static EIGEN_DONT_INLINE void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha) {
+ template<typename ProductType, typename Dest, typename Func>
+ static EIGEN_DONT_INLINE void run(const ProductType& prod, Dest& dest, const Func& func) {
typedef typename Dest::Index Index;
// FIXME make sure lhs is sequentially stored
// FIXME not very good if rhs is real and lhs complex while alpha is real too
const Index cols = dest.cols();
for (Index j=0; j<cols; ++j)
- dest.col(j) += (alpha * prod.rhs().coeff(j)) * prod.lhs();
+ func(dest.col(j), prod.rhs().coeff(j) * prod.lhs());
}
};
template<> struct outer_product_selector<RowMajor> {
- template<typename ProductType, typename Dest>
- static EIGEN_DONT_INLINE void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha) {
+ template<typename ProductType, typename Dest, typename Func>
+ static EIGEN_DONT_INLINE void run(const ProductType& prod, Dest& dest, const Func& func) {
typedef typename Dest::Index Index;
// FIXME make sure rhs is sequentially stored
// FIXME not very good if lhs is real and rhs complex while alpha is real too
const Index rows = dest.rows();
for (Index i=0; i<rows; ++i)
- dest.row(i) += (alpha * prod.lhs().coeff(i)) * prod.rhs();
+ func(dest.row(i), prod.lhs().coeff(i) * prod.rhs());
}
};