diff options
author | Gael Guennebaud <g.gael@free.fr> | 2013-04-19 11:21:39 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2013-04-19 11:21:39 +0200 |
commit | 9cd2d14005def8e7df0b0bf5fd6eb51f8a6591e9 (patch) | |
tree | ca4df13b58e923bdebd9d5f59aecda9d1e30ca58 /Eigen/src/Core/GeneralProduct.h | |
parent | 4e2e615a7c2c719d2d708ab32840bad353322d8c (diff) | |
parent | 46755648ec341aa5e0283b47456108bb2897b1b3 (diff) |
merge with default branch
Diffstat (limited to 'Eigen/src/Core/GeneralProduct.h')
-rw-r--r-- | Eigen/src/Core/GeneralProduct.h | 96 |
1 files changed, 59 insertions, 37 deletions
diff --git a/Eigen/src/Core/GeneralProduct.h b/Eigen/src/Core/GeneralProduct.h index a070e618d..557286003 100644 --- a/Eigen/src/Core/GeneralProduct.h +++ b/Eigen/src/Core/GeneralProduct.h @@ -222,7 +222,29 @@ class GeneralProduct<Lhs, Rhs, InnerProduct> ***********************************************************************/ namespace internal { -template<int StorageOrder> struct outer_product_selector; + +// Column major +template<typename ProductType, typename Dest, typename Func> +EIGEN_DONT_INLINE void outer_product_selector_run(const ProductType& prod, Dest& dest, const Func& func, const false_type&) +{ + typedef typename Dest::Index Index; + // FIXME make sure lhs is sequentially stored + // FIXME not very good if rhs is real and lhs complex while alpha is real too + const Index cols = dest.cols(); + for (Index j=0; j<cols; ++j) + func(dest.col(j), prod.rhs().coeff(j) * prod.lhs()); +} + +// Row major +template<typename ProductType, typename Dest, typename Func> +EIGEN_DONT_INLINE void outer_product_selector_run(const ProductType& prod, Dest& dest, const Func& func, const true_type&) { + typedef typename Dest::Index Index; + // FIXME make sure rhs is sequentially stored + // FIXME not very good if lhs is real and rhs complex while alpha is real too + const Index rows = dest.rows(); + for (Index i=0; i<rows; ++i) + func(dest.row(i), prod.lhs().coeff(i) * prod.rhs()); +} template<typename Lhs, typename Rhs> struct traits<GeneralProduct<Lhs,Rhs,OuterProduct> > @@ -235,6 +257,8 @@ template<typename Lhs, typename Rhs> class GeneralProduct<Lhs, Rhs, OuterProduct> : public ProductBase<GeneralProduct<Lhs,Rhs,OuterProduct>, Lhs, Rhs> { + template<typename T> struct IsRowMajor : internal::conditional<(int(T::Flags)&RowMajorBit), internal::true_type, internal::false_type>::type {}; + public: EIGEN_PRODUCT_PUBLIC_INTERFACE(GeneralProduct) @@ -243,41 +267,39 @@ class GeneralProduct<Lhs, Rhs, OuterProduct> EIGEN_STATIC_ASSERT((internal::is_same<typename Lhs::RealScalar, typename Rhs::RealScalar>::value), YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY) } - - template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const - { - internal::outer_product_selector<(int(Dest::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dest, alpha); + + struct set { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() = src; } }; + struct add { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() += src; } }; + struct sub { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() -= src; } }; + struct adds { + Scalar m_scale; + adds(const Scalar& s) : m_scale(s) {} + template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { + dst.const_cast_derived() += m_scale * src; + } + }; + + template<typename Dest> + inline void evalTo(Dest& dest) const { + internal::outer_product_selector_run(*this, dest, set(), IsRowMajor<Dest>()); + } + + template<typename Dest> + inline void addTo(Dest& dest) const { + internal::outer_product_selector_run(*this, dest, add(), IsRowMajor<Dest>()); } -}; - -namespace internal { -template<> struct outer_product_selector<ColMajor> { - template<typename ProductType, typename Dest> - static EIGEN_DONT_INLINE void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha) { - typedef typename Dest::Index Index; - // FIXME make sure lhs is sequentially stored - // FIXME not very good if rhs is real and lhs complex while alpha is real too - const Index cols = dest.cols(); - for (Index j=0; j<cols; ++j) - dest.col(j) += (alpha * prod.rhs().coeff(j)) * prod.lhs(); - } -}; + template<typename Dest> + inline void subTo(Dest& dest) const { + internal::outer_product_selector_run(*this, dest, sub(), IsRowMajor<Dest>()); + } -template<> struct outer_product_selector<RowMajor> { - template<typename ProductType, typename Dest> - static EIGEN_DONT_INLINE void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha) { - typedef typename Dest::Index Index; - // FIXME make sure rhs is sequentially stored - // FIXME not very good if lhs is real and rhs complex while alpha is real too - const Index rows = dest.rows(); - for (Index i=0; i<rows; ++i) - dest.row(i) += (alpha * prod.lhs().coeff(i)) * prod.rhs(); - } + template<typename Dest> void scaleAndAddTo(Dest& dest, const Scalar& alpha) const + { + internal::outer_product_selector_run(*this, dest, adds(alpha), IsRowMajor<Dest>()); + } }; -} // end namespace internal - /*********************************************************************** * Implementation of General Matrix Vector Product ***********************************************************************/ @@ -320,7 +342,7 @@ class GeneralProduct<Lhs, Rhs, GemvProduct> enum { Side = Lhs::IsVectorAtCompileTime ? OnTheLeft : OnTheRight }; typedef typename internal::conditional<int(Side)==OnTheRight,_LhsNested,_RhsNested>::type MatrixType; - template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const + template<typename Dest> void scaleAndAddTo(Dest& dst, const Scalar& alpha) const { eigen_assert(m_lhs.rows() == dst.rows() && m_rhs.cols() == dst.cols()); internal::gemv_selector<Side,(int(MatrixType::Flags)&RowMajorBit) ? RowMajor : ColMajor, @@ -335,7 +357,7 @@ template<int StorageOrder, bool BlasCompatible> struct gemv_selector<OnTheLeft,StorageOrder,BlasCompatible> { template<typename ProductType, typename Dest> - static void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha) + static void run(const ProductType& prod, Dest& dest, const typename ProductType::Scalar& alpha) { Transpose<Dest> destT(dest); enum { OtherStorageOrder = StorageOrder == RowMajor ? ColMajor : RowMajor }; @@ -384,7 +406,7 @@ struct gemv_static_vector_if<Scalar,Size,MaxSize,true> template<> struct gemv_selector<OnTheRight,ColMajor,true> { template<typename ProductType, typename Dest> - static inline void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha) + static inline void run(const ProductType& prod, Dest& dest, const typename ProductType::Scalar& alpha) { typedef typename ProductType::Index Index; typedef typename ProductType::LhsScalar LhsScalar; @@ -457,7 +479,7 @@ template<> struct gemv_selector<OnTheRight,ColMajor,true> template<> struct gemv_selector<OnTheRight,RowMajor,true> { template<typename ProductType, typename Dest> - static void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha) + static void run(const ProductType& prod, Dest& dest, const typename ProductType::Scalar& alpha) { typedef typename ProductType::LhsScalar LhsScalar; typedef typename ProductType::RhsScalar RhsScalar; @@ -508,7 +530,7 @@ template<> struct gemv_selector<OnTheRight,RowMajor,true> template<> struct gemv_selector<OnTheRight,ColMajor,false> { template<typename ProductType, typename Dest> - static void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha) + static void run(const ProductType& prod, Dest& dest, const typename ProductType::Scalar& alpha) { typedef typename Dest::Index Index; // TODO makes sure dest is sequentially stored in memory, otherwise use a temp @@ -521,7 +543,7 @@ template<> struct gemv_selector<OnTheRight,ColMajor,false> template<> struct gemv_selector<OnTheRight,RowMajor,false> { template<typename ProductType, typename Dest> - static void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha) + static void run(const ProductType& prod, Dest& dest, const typename ProductType::Scalar& alpha) { typedef typename Dest::Index Index; // TODO makes sure rhs is sequentially stored in memory, otherwise use a temp |