aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2014-02-21 16:27:24 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2014-02-21 16:27:24 +0100
commit728c3d2cb955a255cae5515197ae65dc83209509 (patch)
tree04a0ee62f3a16432c58fb45d122bd4c1b60c60fb /Eigen/src/Core
parentaf31b6c37a3b4b32c8075d94b39a78108f12fd31 (diff)
Get rid of GeneralProduct for outer-products, and get rid of ScaledProduct
Diffstat (limited to 'Eigen/src/Core')
-rw-r--r--Eigen/src/Core/GeneralProduct.h3
-rw-r--r--Eigen/src/Core/ProductBase.h3
-rw-r--r--Eigen/src/Core/ProductEvaluators.h71
3 files changed, 66 insertions, 11 deletions
diff --git a/Eigen/src/Core/GeneralProduct.h b/Eigen/src/Core/GeneralProduct.h
index f823ff251..4c0fc7f63 100644
--- a/Eigen/src/Core/GeneralProduct.h
+++ b/Eigen/src/Core/GeneralProduct.h
@@ -247,6 +247,7 @@ class GeneralProduct<Lhs, Rhs, InnerProduct>
* Implementation of Outer Vector Vector Product
***********************************************************************/
+#ifndef EIGEN_TEST_EVALUATORS
namespace internal {
// Column major
@@ -326,6 +327,8 @@ class GeneralProduct<Lhs, Rhs, OuterProduct>
}
};
+#endif // EIGEN_TEST_EVALUATORS
+
/***********************************************************************
* Implementation of General Matrix Vector Product
***********************************************************************/
diff --git a/Eigen/src/Core/ProductBase.h b/Eigen/src/Core/ProductBase.h
index a494b5f87..f6b719d19 100644
--- a/Eigen/src/Core/ProductBase.h
+++ b/Eigen/src/Core/ProductBase.h
@@ -174,6 +174,7 @@ class ProductBase : public MatrixBase<Derived>
mutable PlainObject m_result;
};
+#ifndef EIGEN_TEST_EVALUATORS
// here we need to overload the nested rule for products
// such that the nested type is a const reference to a plain matrix
namespace internal {
@@ -263,6 +264,8 @@ class ScaledProduct
Scalar m_alpha;
};
+#endif // EIGEN_TEST_EVALUATORS
+
/** \internal
* Overloaded to perform an efficient C = (A*B).lazy() */
template<typename Derived>
diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h
index cf612d58a..93ae5f5f5 100644
--- a/Eigen/src/Core/ProductEvaluators.h
+++ b/Eigen/src/Core/ProductEvaluators.h
@@ -17,7 +17,11 @@ namespace Eigen {
namespace internal {
-// Like more general binary expressions, products need their own evaluator:
+/** \internal
+ * \class product_evaluator
+ * Products need their own evaluator with more template arguments allowing for
+ * easier partial template specializations.
+ */
template< typename T,
int ProductTag = internal::product_type<typename T::Lhs,typename T::Rhs>::ret,
typename LhsShape = typename evaluator_traits<typename T::Lhs>::Shape,
@@ -26,6 +30,14 @@ template< typename T,
typename RhsScalar = typename T::Rhs::Scalar
> struct product_evaluator;
+/** \internal
+ * Evaluator of a product expression.
+ * Since products require special treatments to handle all possible cases,
+ * we simply deffer the evaluation logic to a product_evaluator class
+ * which offers more partial specialization possibilities.
+ *
+ * \sa class product_evaluator
+ */
template<typename Lhs, typename Rhs, int Options>
struct evaluator<Product<Lhs, Rhs, Options> >
: public product_evaluator<Product<Lhs, Rhs, Options> >
@@ -40,7 +52,7 @@ struct evaluator<Product<Lhs, Rhs, Options> >
};
// Catch scalar * ( A * B ) and transform it to (A*scalar) * B
-// TODO we should apply that rule if that's really helpful
+// TODO we should apply that rule only if that's really helpful
template<typename Lhs, typename Rhs, typename Scalar>
struct evaluator<CwiseUnaryOp<internal::scalar_multiple_op<Scalar>, const Product<Lhs, Rhs, DefaultProduct> > >
: public evaluator<Product<CwiseUnaryOp<internal::scalar_multiple_op<Scalar>,const Lhs>, Rhs, DefaultProduct> >
@@ -66,7 +78,7 @@ struct evaluator<Diagonal<const Product<Lhs, Rhs, DefaultProduct>, DiagIndex> >
typedef evaluator type;
typedef evaluator nestedType;
-//
+
evaluator(const XprType& xpr)
: Base(Diagonal<const Product<Lhs, Rhs, LazyProduct>, DiagIndex>(
Product<Lhs, Rhs, LazyProduct>(xpr.nestedExpression().lhs(), xpr.nestedExpression().rhs()),
@@ -183,38 +195,75 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,InnerProduct>
};
+/***********************************************************************
+* Implementation of outer dense * dense vector product
+***********************************************************************/
+
+// Column major result
+template<typename Dst, typename Lhs, typename Rhs, typename Func>
+EIGEN_DONT_INLINE void outer_product_selector_run(Dst& dst, const Lhs &lhs, const Rhs &rhs, const Func& func, const false_type&)
+{
+ typedef typename Dst::Index Index;
+ // FIXME make sure lhs is sequentially stored
+ // FIXME not very good if rhs is real and lhs complex while alpha is real too
+ // FIXME we should probably build an evaluator for dst and rhs
+ const Index cols = dst.cols();
+ for (Index j=0; j<cols; ++j)
+ func(dst.col(j), rhs.coeff(j) * lhs);
+}
+
+// Row major result
+template<typename Dst, typename Lhs, typename Rhs, typename Func>
+EIGEN_DONT_INLINE void outer_product_selector_run(Dst& dst, const Lhs &lhs, const Rhs &rhs, const Func& func, const true_type&) {
+ typedef typename Dst::Index Index;
+ // FIXME make sure rhs is sequentially stored
+ // FIXME not very good if lhs is real and rhs complex while alpha is real too
+ // FIXME we should probably build an evaluator for dst and lhs
+ const Index rows = dst.rows();
+ for (Index i=0; i<rows; ++i)
+ func(dst.row(i), lhs.coeff(i) * rhs);
+}
template<typename Lhs, typename Rhs>
struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,OuterProduct>
{
+ template<typename T> struct IsRowMajor : internal::conditional<(int(T::Flags)&RowMajorBit), internal::true_type, internal::false_type>::type {};
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
+ // TODO it would be nice to be able to exploit our *_assign_op functors for that purpose
+ struct set { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() = src; } };
+ struct add { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() += src; } };
+ struct sub { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() -= src; } };
+ struct adds {
+ Scalar m_scale;
+ adds(const Scalar& s) : m_scale(s) {}
+ template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const {
+ dst.const_cast_derived() += m_scale * src;
+ }
+ };
+
template<typename Dst>
static inline void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{
- // TODO bypass GeneralProduct class
- GeneralProduct<Lhs, Rhs, OuterProduct>(lhs,rhs).evalTo(dst);
+ internal::outer_product_selector_run(dst, lhs, rhs, set(), IsRowMajor<Dst>());
}
template<typename Dst>
static inline void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{
- // TODO bypass GeneralProduct class
- GeneralProduct<Lhs, Rhs, OuterProduct>(lhs,rhs).addTo(dst);
+ internal::outer_product_selector_run(dst, lhs, rhs, add(), IsRowMajor<Dst>());
}
template<typename Dst>
static inline void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{
- // TODO bypass GeneralProduct class
- GeneralProduct<Lhs, Rhs, OuterProduct>(lhs,rhs).subTo(dst);
+ internal::outer_product_selector_run(dst, lhs, rhs, sub(), IsRowMajor<Dst>());
}
template<typename Dst>
static inline void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
{
- // TODO bypass GeneralProduct class
- GeneralProduct<Lhs, Rhs, OuterProduct>(lhs,rhs).scaleAndAddTo(dst, alpha);
+ internal::outer_product_selector_run(dst, lhs, rhs, adds(alpha), IsRowMajor<Dst>());
}
};