diff options
Diffstat (limited to 'Eigen/src/Core/Product.h')
-rw-r--r-- | Eigen/src/Core/Product.h | 49 |
1 files changed, 32 insertions, 17 deletions
diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h index 6849d90e3..a645ab6de 100644 --- a/Eigen/src/Core/Product.h +++ b/Eigen/src/Core/Product.h @@ -65,12 +65,11 @@ struct ProductReturnType template<typename Lhs, typename Rhs> struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct> { - typedef typename ei_nested<Lhs,Rhs::ColsAtCompileTime>::type LhsNested; - - typedef typename ei_nested<Rhs,Lhs::RowsAtCompileTime, + typedef typename ei_nested<Lhs,1>::type LhsNested; + typedef typename ei_nested<Rhs,1, typename ei_plain_matrix_type_column_major<Rhs>::type >::type RhsNested; - + typedef Product<LhsNested, RhsNested, CacheFriendlyProduct> Type; }; @@ -95,14 +94,14 @@ template<typename Lhs, typename Rhs> struct ei_product_mode template<typename XprType> struct ei_product_factor_traits { typedef typename ei_traits<XprType>::Scalar Scalar; - typedef XprType RealXprType; + typedef XprType ActualXprType; enum { IsComplex = NumTraits<Scalar>::IsComplex, NeedToConjugate = false, HasScalarMultiple = false, Access = int(ei_traits<XprType>::Flags)&DirectAccessBit ? HasDirectAccess : NoDirectAccess }; - static inline const RealXprType& extract(const XprType& x) { return x; } + static inline const ActualXprType& extract(const XprType& x) { return x; } static inline Scalar extractSalarFactor(const XprType&) { return Scalar(1); } }; @@ -112,13 +111,13 @@ template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<Cw { typedef ei_product_factor_traits<NestedXpr> Base; typedef CwiseUnaryOp<ei_scalar_conjugate_op<Scalar>, NestedXpr> XprType; - typedef typename Base::RealXprType RealXprType; + typedef typename Base::ActualXprType ActualXprType; enum { IsComplex = NumTraits<Scalar>::IsComplex, NeedToConjugate = IsComplex }; - static inline const RealXprType& extract(const XprType& x) { return x._expression(); } + static inline const ActualXprType& extract(const XprType& x) { return x._expression(); } static inline Scalar extractSalarFactor(const XprType& x) { return Base::extractSalarFactor(x._expression()); } }; @@ -128,12 +127,12 @@ template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<Cw { typedef ei_product_factor_traits<NestedXpr> Base; typedef CwiseUnaryOp<ei_scalar_multiple_op<Scalar>, NestedXpr> XprType; - typedef typename Base::RealXprType RealXprType; + typedef typename Base::ActualXprType ActualXprType; enum { HasScalarMultiple = true }; - static inline const RealXprType& extract(const XprType& x) { return x._expression(); } - static inline Scalar extractSalarFactor(const XprType& x) { return x._functor().value; } + static inline const ActualXprType& extract(const XprType& x) { return x._expression(); } + static inline Scalar extractSalarFactor(const XprType& x) { return x._functor().m_other; } }; /** \class Product @@ -819,18 +818,34 @@ template<typename Lhs, typename Rhs, int ProductMode> template<typename DestDerived> inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived& res, Scalar alpha) const { - typedef typename ei_product_copy_lhs<_LhsNested>::type LhsCopy; + typedef ei_product_factor_traits<_LhsNested> LhsProductTraits; + typedef ei_product_factor_traits<_RhsNested> RhsProductTraits; + + typedef typename LhsProductTraits::ActualXprType ActualLhsType; + typedef typename RhsProductTraits::ActualXprType ActualRhsType; + + const ActualLhsType& actualLhs = LhsProductTraits::extract(m_lhs); + const ActualRhsType& actualRhs = RhsProductTraits::extract(m_rhs); + + Scalar actualAlpha = alpha * LhsProductTraits::extractSalarFactor(m_lhs) + * RhsProductTraits::extractSalarFactor(m_rhs); + + typedef typename ei_product_copy_lhs<ActualLhsType>::type LhsCopy; typedef typename ei_unref<LhsCopy>::type _LhsCopy; - typedef typename ei_product_copy_rhs<_RhsNested>::type RhsCopy; + typedef typename ei_product_copy_rhs<ActualRhsType>::type RhsCopy; typedef typename ei_unref<RhsCopy>::type _RhsCopy; - LhsCopy lhs(m_lhs); - RhsCopy rhs(m_rhs); - ei_cache_friendly_product<Scalar,false,false>( + LhsCopy lhs(actualLhs); + RhsCopy rhs(actualRhs); + ei_cache_friendly_product<Scalar, +// LhsProductTraits::NeedToConjugate,RhsProductTraits::NeedToConjugate> + ((int(Flags)&RowMajorBit) ? bool(RhsProductTraits::NeedToConjugate) : bool(LhsProductTraits::NeedToConjugate)), + ((int(Flags)&RowMajorBit) ? bool(LhsProductTraits::NeedToConjugate) : bool(RhsProductTraits::NeedToConjugate))> + ( rows(), cols(), lhs.cols(), _LhsCopy::Flags&RowMajorBit, (const Scalar*)&(lhs.const_cast_derived().coeffRef(0,0)), lhs.stride(), _RhsCopy::Flags&RowMajorBit, (const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(), Flags&RowMajorBit, (Scalar*)&(res.coeffRef(0,0)), res.stride(), - alpha + actualAlpha ); } |