aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/GeneralProduct.h
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2013-04-19 11:21:39 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2013-04-19 11:21:39 +0200
commit9cd2d14005def8e7df0b0bf5fd6eb51f8a6591e9 (patch)
treeca4df13b58e923bdebd9d5f59aecda9d1e30ca58 /Eigen/src/Core/GeneralProduct.h
parent4e2e615a7c2c719d2d708ab32840bad353322d8c (diff)
parent46755648ec341aa5e0283b47456108bb2897b1b3 (diff)
merge with default branch
Diffstat (limited to 'Eigen/src/Core/GeneralProduct.h')
-rw-r--r--Eigen/src/Core/GeneralProduct.h96
1 files changed, 59 insertions, 37 deletions
diff --git a/Eigen/src/Core/GeneralProduct.h b/Eigen/src/Core/GeneralProduct.h
index a070e618d..557286003 100644
--- a/Eigen/src/Core/GeneralProduct.h
+++ b/Eigen/src/Core/GeneralProduct.h
@@ -222,7 +222,29 @@ class GeneralProduct<Lhs, Rhs, InnerProduct>
***********************************************************************/
namespace internal {
-template<int StorageOrder> struct outer_product_selector;
+
+// Column major
+template<typename ProductType, typename Dest, typename Func>
+EIGEN_DONT_INLINE void outer_product_selector_run(const ProductType& prod, Dest& dest, const Func& func, const false_type&)
+{
+ typedef typename Dest::Index Index;
+ // FIXME make sure lhs is sequentially stored
+ // FIXME not very good if rhs is real and lhs complex while alpha is real too
+ const Index cols = dest.cols();
+ for (Index j=0; j<cols; ++j)
+ func(dest.col(j), prod.rhs().coeff(j) * prod.lhs());
+}
+
+// Row major
+template<typename ProductType, typename Dest, typename Func>
+EIGEN_DONT_INLINE void outer_product_selector_run(const ProductType& prod, Dest& dest, const Func& func, const true_type&) {
+ typedef typename Dest::Index Index;
+ // FIXME make sure rhs is sequentially stored
+ // FIXME not very good if lhs is real and rhs complex while alpha is real too
+ const Index rows = dest.rows();
+ for (Index i=0; i<rows; ++i)
+ func(dest.row(i), prod.lhs().coeff(i) * prod.rhs());
+}
template<typename Lhs, typename Rhs>
struct traits<GeneralProduct<Lhs,Rhs,OuterProduct> >
@@ -235,6 +257,8 @@ template<typename Lhs, typename Rhs>
class GeneralProduct<Lhs, Rhs, OuterProduct>
: public ProductBase<GeneralProduct<Lhs,Rhs,OuterProduct>, Lhs, Rhs>
{
+ template<typename T> struct IsRowMajor : internal::conditional<(int(T::Flags)&RowMajorBit), internal::true_type, internal::false_type>::type {};
+
public:
EIGEN_PRODUCT_PUBLIC_INTERFACE(GeneralProduct)
@@ -243,41 +267,39 @@ class GeneralProduct<Lhs, Rhs, OuterProduct>
EIGEN_STATIC_ASSERT((internal::is_same<typename Lhs::RealScalar, typename Rhs::RealScalar>::value),
YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
}
-
- template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
- {
- internal::outer_product_selector<(int(Dest::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dest, alpha);
+
+ struct set { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() = src; } };
+ struct add { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() += src; } };
+ struct sub { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() -= src; } };
+ struct adds {
+ Scalar m_scale;
+ adds(const Scalar& s) : m_scale(s) {}
+ template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const {
+ dst.const_cast_derived() += m_scale * src;
+ }
+ };
+
+ template<typename Dest>
+ inline void evalTo(Dest& dest) const {
+ internal::outer_product_selector_run(*this, dest, set(), IsRowMajor<Dest>());
+ }
+
+ template<typename Dest>
+ inline void addTo(Dest& dest) const {
+ internal::outer_product_selector_run(*this, dest, add(), IsRowMajor<Dest>());
}
-};
-
-namespace internal {
-template<> struct outer_product_selector<ColMajor> {
- template<typename ProductType, typename Dest>
- static EIGEN_DONT_INLINE void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha) {
- typedef typename Dest::Index Index;
- // FIXME make sure lhs is sequentially stored
- // FIXME not very good if rhs is real and lhs complex while alpha is real too
- const Index cols = dest.cols();
- for (Index j=0; j<cols; ++j)
- dest.col(j) += (alpha * prod.rhs().coeff(j)) * prod.lhs();
- }
-};
+ template<typename Dest>
+ inline void subTo(Dest& dest) const {
+ internal::outer_product_selector_run(*this, dest, sub(), IsRowMajor<Dest>());
+ }
-template<> struct outer_product_selector<RowMajor> {
- template<typename ProductType, typename Dest>
- static EIGEN_DONT_INLINE void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha) {
- typedef typename Dest::Index Index;
- // FIXME make sure rhs is sequentially stored
- // FIXME not very good if lhs is real and rhs complex while alpha is real too
- const Index rows = dest.rows();
- for (Index i=0; i<rows; ++i)
- dest.row(i) += (alpha * prod.lhs().coeff(i)) * prod.rhs();
- }
+ template<typename Dest> void scaleAndAddTo(Dest& dest, const Scalar& alpha) const
+ {
+ internal::outer_product_selector_run(*this, dest, adds(alpha), IsRowMajor<Dest>());
+ }
};
-} // end namespace internal
-
/***********************************************************************
* Implementation of General Matrix Vector Product
***********************************************************************/
@@ -320,7 +342,7 @@ class GeneralProduct<Lhs, Rhs, GemvProduct>
enum { Side = Lhs::IsVectorAtCompileTime ? OnTheLeft : OnTheRight };
typedef typename internal::conditional<int(Side)==OnTheRight,_LhsNested,_RhsNested>::type MatrixType;
- template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
+ template<typename Dest> void scaleAndAddTo(Dest& dst, const Scalar& alpha) const
{
eigen_assert(m_lhs.rows() == dst.rows() && m_rhs.cols() == dst.cols());
internal::gemv_selector<Side,(int(MatrixType::Flags)&RowMajorBit) ? RowMajor : ColMajor,
@@ -335,7 +357,7 @@ template<int StorageOrder, bool BlasCompatible>
struct gemv_selector<OnTheLeft,StorageOrder,BlasCompatible>
{
template<typename ProductType, typename Dest>
- static void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha)
+ static void run(const ProductType& prod, Dest& dest, const typename ProductType::Scalar& alpha)
{
Transpose<Dest> destT(dest);
enum { OtherStorageOrder = StorageOrder == RowMajor ? ColMajor : RowMajor };
@@ -384,7 +406,7 @@ struct gemv_static_vector_if<Scalar,Size,MaxSize,true>
template<> struct gemv_selector<OnTheRight,ColMajor,true>
{
template<typename ProductType, typename Dest>
- static inline void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha)
+ static inline void run(const ProductType& prod, Dest& dest, const typename ProductType::Scalar& alpha)
{
typedef typename ProductType::Index Index;
typedef typename ProductType::LhsScalar LhsScalar;
@@ -457,7 +479,7 @@ template<> struct gemv_selector<OnTheRight,ColMajor,true>
template<> struct gemv_selector<OnTheRight,RowMajor,true>
{
template<typename ProductType, typename Dest>
- static void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha)
+ static void run(const ProductType& prod, Dest& dest, const typename ProductType::Scalar& alpha)
{
typedef typename ProductType::LhsScalar LhsScalar;
typedef typename ProductType::RhsScalar RhsScalar;
@@ -508,7 +530,7 @@ template<> struct gemv_selector<OnTheRight,RowMajor,true>
template<> struct gemv_selector<OnTheRight,ColMajor,false>
{
template<typename ProductType, typename Dest>
- static void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha)
+ static void run(const ProductType& prod, Dest& dest, const typename ProductType::Scalar& alpha)
{
typedef typename Dest::Index Index;
// TODO makes sure dest is sequentially stored in memory, otherwise use a temp
@@ -521,7 +543,7 @@ template<> struct gemv_selector<OnTheRight,ColMajor,false>
template<> struct gemv_selector<OnTheRight,RowMajor,false>
{
template<typename ProductType, typename Dest>
- static void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha)
+ static void run(const ProductType& prod, Dest& dest, const typename ProductType::Scalar& alpha)
{
typedef typename Dest::Index Index;
// TODO makes sure rhs is sequentially stored in memory, otherwise use a temp