diff options
author | 2014-02-21 16:27:24 +0100 | |
---|---|---|
committer | 2014-02-21 16:27:24 +0100 | |
commit | 728c3d2cb955a255cae5515197ae65dc83209509 (patch) | |
tree | 04a0ee62f3a16432c58fb45d122bd4c1b60c60fb /Eigen/src/Core | |
parent | af31b6c37a3b4b32c8075d94b39a78108f12fd31 (diff) |
Get rid of GeneralProduct for outer-products, and get rid of ScaledProduct
Diffstat (limited to 'Eigen/src/Core')
-rw-r--r-- | Eigen/src/Core/GeneralProduct.h | 3 | ||||
-rw-r--r-- | Eigen/src/Core/ProductBase.h | 3 | ||||
-rw-r--r-- | Eigen/src/Core/ProductEvaluators.h | 71 |
3 files changed, 66 insertions, 11 deletions
diff --git a/Eigen/src/Core/GeneralProduct.h b/Eigen/src/Core/GeneralProduct.h index f823ff251..4c0fc7f63 100644 --- a/Eigen/src/Core/GeneralProduct.h +++ b/Eigen/src/Core/GeneralProduct.h @@ -247,6 +247,7 @@ class GeneralProduct<Lhs, Rhs, InnerProduct> * Implementation of Outer Vector Vector Product ***********************************************************************/ +#ifndef EIGEN_TEST_EVALUATORS namespace internal { // Column major @@ -326,6 +327,8 @@ class GeneralProduct<Lhs, Rhs, OuterProduct> } }; +#endif // EIGEN_TEST_EVALUATORS + /*********************************************************************** * Implementation of General Matrix Vector Product ***********************************************************************/ diff --git a/Eigen/src/Core/ProductBase.h b/Eigen/src/Core/ProductBase.h index a494b5f87..f6b719d19 100644 --- a/Eigen/src/Core/ProductBase.h +++ b/Eigen/src/Core/ProductBase.h @@ -174,6 +174,7 @@ class ProductBase : public MatrixBase<Derived> mutable PlainObject m_result; }; +#ifndef EIGEN_TEST_EVALUATORS // here we need to overload the nested rule for products // such that the nested type is a const reference to a plain matrix namespace internal { @@ -263,6 +264,8 @@ class ScaledProduct Scalar m_alpha; }; +#endif // EIGEN_TEST_EVALUATORS + /** \internal * Overloaded to perform an efficient C = (A*B).lazy() */ template<typename Derived> diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h index cf612d58a..93ae5f5f5 100644 --- a/Eigen/src/Core/ProductEvaluators.h +++ b/Eigen/src/Core/ProductEvaluators.h @@ -17,7 +17,11 @@ namespace Eigen { namespace internal { -// Like more general binary expressions, products need their own evaluator: +/** \internal + * \class product_evaluator + * Products need their own evaluator with more template arguments allowing for + * easier partial template specializations. + */ template< typename T, int ProductTag = internal::product_type<typename T::Lhs,typename T::Rhs>::ret, typename LhsShape = typename evaluator_traits<typename T::Lhs>::Shape, @@ -26,6 +30,14 @@ template< typename T, typename RhsScalar = typename T::Rhs::Scalar > struct product_evaluator; +/** \internal + * Evaluator of a product expression. + * Since products require special treatments to handle all possible cases, + * we simply deffer the evaluation logic to a product_evaluator class + * which offers more partial specialization possibilities. + * + * \sa class product_evaluator + */ template<typename Lhs, typename Rhs, int Options> struct evaluator<Product<Lhs, Rhs, Options> > : public product_evaluator<Product<Lhs, Rhs, Options> > @@ -40,7 +52,7 @@ struct evaluator<Product<Lhs, Rhs, Options> > }; // Catch scalar * ( A * B ) and transform it to (A*scalar) * B -// TODO we should apply that rule if that's really helpful +// TODO we should apply that rule only if that's really helpful template<typename Lhs, typename Rhs, typename Scalar> struct evaluator<CwiseUnaryOp<internal::scalar_multiple_op<Scalar>, const Product<Lhs, Rhs, DefaultProduct> > > : public evaluator<Product<CwiseUnaryOp<internal::scalar_multiple_op<Scalar>,const Lhs>, Rhs, DefaultProduct> > @@ -66,7 +78,7 @@ struct evaluator<Diagonal<const Product<Lhs, Rhs, DefaultProduct>, DiagIndex> > typedef evaluator type; typedef evaluator nestedType; -// + evaluator(const XprType& xpr) : Base(Diagonal<const Product<Lhs, Rhs, LazyProduct>, DiagIndex>( Product<Lhs, Rhs, LazyProduct>(xpr.nestedExpression().lhs(), xpr.nestedExpression().rhs()), @@ -183,38 +195,75 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,InnerProduct> }; +/*********************************************************************** +* Implementation of outer dense * dense vector product +***********************************************************************/ + +// Column major result +template<typename Dst, typename Lhs, typename Rhs, typename Func> +EIGEN_DONT_INLINE void outer_product_selector_run(Dst& dst, const Lhs &lhs, const Rhs &rhs, const Func& func, const false_type&) +{ + typedef typename Dst::Index Index; + // FIXME make sure lhs is sequentially stored + // FIXME not very good if rhs is real and lhs complex while alpha is real too + // FIXME we should probably build an evaluator for dst and rhs + const Index cols = dst.cols(); + for (Index j=0; j<cols; ++j) + func(dst.col(j), rhs.coeff(j) * lhs); +} + +// Row major result +template<typename Dst, typename Lhs, typename Rhs, typename Func> +EIGEN_DONT_INLINE void outer_product_selector_run(Dst& dst, const Lhs &lhs, const Rhs &rhs, const Func& func, const true_type&) { + typedef typename Dst::Index Index; + // FIXME make sure rhs is sequentially stored + // FIXME not very good if lhs is real and rhs complex while alpha is real too + // FIXME we should probably build an evaluator for dst and lhs + const Index rows = dst.rows(); + for (Index i=0; i<rows; ++i) + func(dst.row(i), lhs.coeff(i) * rhs); +} template<typename Lhs, typename Rhs> struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,OuterProduct> { + template<typename T> struct IsRowMajor : internal::conditional<(int(T::Flags)&RowMajorBit), internal::true_type, internal::false_type>::type {}; typedef typename Product<Lhs,Rhs>::Scalar Scalar; + // TODO it would be nice to be able to exploit our *_assign_op functors for that purpose + struct set { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() = src; } }; + struct add { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() += src; } }; + struct sub { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() -= src; } }; + struct adds { + Scalar m_scale; + adds(const Scalar& s) : m_scale(s) {} + template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { + dst.const_cast_derived() += m_scale * src; + } + }; + template<typename Dst> static inline void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) { - // TODO bypass GeneralProduct class - GeneralProduct<Lhs, Rhs, OuterProduct>(lhs,rhs).evalTo(dst); + internal::outer_product_selector_run(dst, lhs, rhs, set(), IsRowMajor<Dst>()); } template<typename Dst> static inline void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) { - // TODO bypass GeneralProduct class - GeneralProduct<Lhs, Rhs, OuterProduct>(lhs,rhs).addTo(dst); + internal::outer_product_selector_run(dst, lhs, rhs, add(), IsRowMajor<Dst>()); } template<typename Dst> static inline void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) { - // TODO bypass GeneralProduct class - GeneralProduct<Lhs, Rhs, OuterProduct>(lhs,rhs).subTo(dst); + internal::outer_product_selector_run(dst, lhs, rhs, sub(), IsRowMajor<Dst>()); } template<typename Dst> static inline void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) { - // TODO bypass GeneralProduct class - GeneralProduct<Lhs, Rhs, OuterProduct>(lhs,rhs).scaleAndAddTo(dst, alpha); + internal::outer_product_selector_run(dst, lhs, rhs, adds(alpha), IsRowMajor<Dst>()); } }; |