From 13b2dafb5033a9de83c3dbd038b06c45845aeac1 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Tue, 7 Jul 2009 21:30:20 +0200 Subject: conjugate expressions are now properly caught by Product => significant speedup in expr. like a.adjoint() * b, for complex scalar type (~ x3) --- Eigen/src/Core/CwiseUnaryOp.h | 3 +- Eigen/src/Core/Product.h | 49 ++++++---- Eigen/src/Core/products/GeneralMatrixMatrix.h | 135 ++++++++++++++------------ 3 files changed, 107 insertions(+), 80 deletions(-) (limited to 'Eigen') diff --git a/Eigen/src/Core/CwiseUnaryOp.h b/Eigen/src/Core/CwiseUnaryOp.h index 0095a1572..3ffb24833 100644 --- a/Eigen/src/Core/CwiseUnaryOp.h +++ b/Eigen/src/Core/CwiseUnaryOp.h @@ -96,7 +96,8 @@ class CwiseUnaryOp : ei_no_assignment_operator, const UnaryOp& _functor() const { return m_functor; } /** \internal used for introspection */ - const typename MatrixType::Nested& _expression() const { return m_matrix; } + const typename ei_cleantype::type& + _expression() const { return m_matrix; } protected: const typename MatrixType::Nested m_matrix; diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h index 6849d90e3..a645ab6de 100644 --- a/Eigen/src/Core/Product.h +++ b/Eigen/src/Core/Product.h @@ -65,12 +65,11 @@ struct ProductReturnType template struct ProductReturnType { - typedef typename ei_nested::type LhsNested; - - typedef typename ei_nested::type LhsNested; + typedef typename ei_nested::type >::type RhsNested; - + typedef Product Type; }; @@ -95,14 +94,14 @@ template struct ei_product_mode template struct ei_product_factor_traits { typedef typename ei_traits::Scalar Scalar; - typedef XprType RealXprType; + typedef XprType ActualXprType; enum { IsComplex = NumTraits::IsComplex, NeedToConjugate = false, HasScalarMultiple = false, Access = int(ei_traits::Flags)&DirectAccessBit ? HasDirectAccess : NoDirectAccess }; - static inline const RealXprType& extract(const XprType& x) { return x; } + static inline const ActualXprType& extract(const XprType& x) { return x; } static inline Scalar extractSalarFactor(const XprType&) { return Scalar(1); } }; @@ -112,13 +111,13 @@ template struct ei_product_factor_traits Base; typedef CwiseUnaryOp, NestedXpr> XprType; - typedef typename Base::RealXprType RealXprType; + typedef typename Base::ActualXprType ActualXprType; enum { IsComplex = NumTraits::IsComplex, NeedToConjugate = IsComplex }; - static inline const RealXprType& extract(const XprType& x) { return x._expression(); } + static inline const ActualXprType& extract(const XprType& x) { return x._expression(); } static inline Scalar extractSalarFactor(const XprType& x) { return Base::extractSalarFactor(x._expression()); } }; @@ -128,12 +127,12 @@ template struct ei_product_factor_traits Base; typedef CwiseUnaryOp, NestedXpr> XprType; - typedef typename Base::RealXprType RealXprType; + typedef typename Base::ActualXprType ActualXprType; enum { HasScalarMultiple = true }; - static inline const RealXprType& extract(const XprType& x) { return x._expression(); } - static inline Scalar extractSalarFactor(const XprType& x) { return x._functor().value; } + static inline const ActualXprType& extract(const XprType& x) { return x._expression(); } + static inline Scalar extractSalarFactor(const XprType& x) { return x._functor().m_other; } }; /** \class Product @@ -819,18 +818,34 @@ template template inline void Product::_cacheFriendlyEvalAndAdd(DestDerived& res, Scalar alpha) const { - typedef typename ei_product_copy_lhs<_LhsNested>::type LhsCopy; + typedef ei_product_factor_traits<_LhsNested> LhsProductTraits; + typedef ei_product_factor_traits<_RhsNested> RhsProductTraits; + + typedef typename LhsProductTraits::ActualXprType ActualLhsType; + typedef typename RhsProductTraits::ActualXprType ActualRhsType; + + const ActualLhsType& actualLhs = LhsProductTraits::extract(m_lhs); + const ActualRhsType& actualRhs = RhsProductTraits::extract(m_rhs); + + Scalar actualAlpha = alpha * LhsProductTraits::extractSalarFactor(m_lhs) + * RhsProductTraits::extractSalarFactor(m_rhs); + + typedef typename ei_product_copy_lhs::type LhsCopy; typedef typename ei_unref::type _LhsCopy; - typedef typename ei_product_copy_rhs<_RhsNested>::type RhsCopy; + typedef typename ei_product_copy_rhs::type RhsCopy; typedef typename ei_unref::type _RhsCopy; - LhsCopy lhs(m_lhs); - RhsCopy rhs(m_rhs); - ei_cache_friendly_product( + LhsCopy lhs(actualLhs); + RhsCopy rhs(actualRhs); + ei_cache_friendly_product + ((int(Flags)&RowMajorBit) ? bool(RhsProductTraits::NeedToConjugate) : bool(LhsProductTraits::NeedToConjugate)), + ((int(Flags)&RowMajorBit) ? bool(LhsProductTraits::NeedToConjugate) : bool(RhsProductTraits::NeedToConjugate))> + ( rows(), cols(), lhs.cols(), _LhsCopy::Flags&RowMajorBit, (const Scalar*)&(lhs.const_cast_derived().coeffRef(0,0)), lhs.stride(), _RhsCopy::Flags&RowMajorBit, (const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(), Flags&RowMajorBit, (Scalar*)&(res.coeffRef(0,0)), res.stride(), - alpha + actualAlpha ); } diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index 4630e5040..db63eadf9 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -58,6 +58,9 @@ template<> struct ei_conj_pmadd #ifndef EIGEN_EXTERN_INSTANTIATIONS +/** \warning you should never call this function directly, + * this is because the ConjugateLhs/ConjugateRhs have to + * be flipped is resRowMajor==true */ template static void ei_cache_friendly_product( int _rows, int _cols, int depth, @@ -76,6 +79,12 @@ static void ei_cache_friendly_product( if (resRowMajor) { +// return ei_cache_friendly_product(_cols,_rows,depth, +// !_rhsRowMajor, _rhs, _rhsStride, +// !_lhsRowMajor, _lhs, _lhsStride, +// false, res, resStride, +// alpha); + lhs = _rhs; rhs = _lhs; lhsStride = _rhsStride; @@ -252,59 +261,59 @@ static void ei_cache_friendly_product( A1 = ei_pload(&blA[1*PacketSize]); B0 = ei_pload(&blB[0*PacketSize]); B1 = ei_pload(&blB[1*PacketSize]); - C0 = cj_pmadd(B0, A0, C0); + C0 = cj_pmadd(A0, B0, C0); if(nr==4) B2 = ei_pload(&blB[2*PacketSize]); - C4 = cj_pmadd(B0, A1, C4); + C4 = cj_pmadd(A1, B0, C4); if(nr==4) B3 = ei_pload(&blB[3*PacketSize]); B0 = ei_pload(&blB[(nr==4 ? 4 : 2)*PacketSize]); - C1 = cj_pmadd(B1, A0, C1); - C5 = cj_pmadd(B1, A1, C5); + C1 = cj_pmadd(A0, B1, C1); + C5 = cj_pmadd(A1, B1, C5); B1 = ei_pload(&blB[(nr==4 ? 5 : 3)*PacketSize]); - if(nr==4) C2 = cj_pmadd(B2, A0, C2); - if(nr==4) C6 = cj_pmadd(B2, A1, C6); + if(nr==4) C2 = cj_pmadd(A0, B2, C2); + if(nr==4) C6 = cj_pmadd(A1, B2, C6); if(nr==4) B2 = ei_pload(&blB[6*PacketSize]); - if(nr==4) C3 = cj_pmadd(B3, A0, C3); + if(nr==4) C3 = cj_pmadd(A0, B3, C3); A0 = ei_pload(&blA[2*PacketSize]); - if(nr==4) C7 = cj_pmadd(B3, A1, C7); + if(nr==4) C7 = cj_pmadd(A1, B3, C7); A1 = ei_pload(&blA[3*PacketSize]); if(nr==4) B3 = ei_pload(&blB[7*PacketSize]); - C0 = cj_pmadd(B0, A0, C0); - C4 = cj_pmadd(B0, A1, C4); + C0 = cj_pmadd(A0, B0, C0); + C4 = cj_pmadd(A1, B0, C4); B0 = ei_pload(&blB[(nr==4 ? 8 : 4)*PacketSize]); - C1 = cj_pmadd(B1, A0, C1); - C5 = cj_pmadd(B1, A1, C5); + C1 = cj_pmadd(A0, B1, C1); + C5 = cj_pmadd(A1, B1, C5); B1 = ei_pload(&blB[(nr==4 ? 9 : 5)*PacketSize]); - if(nr==4) C2 = cj_pmadd(B2, A0, C2); - if(nr==4) C6 = cj_pmadd(B2, A1, C6); + if(nr==4) C2 = cj_pmadd(A0, B2, C2); + if(nr==4) C6 = cj_pmadd(A1, B2, C6); if(nr==4) B2 = ei_pload(&blB[10*PacketSize]); - if(nr==4) C3 = cj_pmadd(B3, A0, C3); + if(nr==4) C3 = cj_pmadd(A0, B3, C3); A0 = ei_pload(&blA[4*PacketSize]); - if(nr==4) C7 = cj_pmadd(B3, A1, C7); + if(nr==4) C7 = cj_pmadd(A1, B3, C7); A1 = ei_pload(&blA[5*PacketSize]); if(nr==4) B3 = ei_pload(&blB[11*PacketSize]); - C0 = cj_pmadd(B0, A0, C0); - C4 = cj_pmadd(B0, A1, C4); + C0 = cj_pmadd(A0, B0, C0); + C4 = cj_pmadd(A1, B0, C4); B0 = ei_pload(&blB[(nr==4 ? 12 : 6)*PacketSize]); - C1 = cj_pmadd(B1, A0, C1); - C5 = cj_pmadd(B1, A1, C5); + C1 = cj_pmadd(A0, B1, C1); + C5 = cj_pmadd(A1, B1, C5); B1 = ei_pload(&blB[(nr==4 ? 13 : 7)*PacketSize]); - if(nr==4) C2 = cj_pmadd(B2, A0, C2); - if(nr==4) C6 = cj_pmadd(B2, A1, C6); + if(nr==4) C2 = cj_pmadd(A0, B2, C2); + if(nr==4) C6 = cj_pmadd(A1, B2, C6); if(nr==4) B2 = ei_pload(&blB[14*PacketSize]); - if(nr==4) C3 = cj_pmadd(B3, A0, C3); + if(nr==4) C3 = cj_pmadd(A0, B3, C3); A0 = ei_pload(&blA[6*PacketSize]); - if(nr==4) C7 = cj_pmadd(B3, A1, C7); + if(nr==4) C7 = cj_pmadd(A1, B3, C7); A1 = ei_pload(&blA[7*PacketSize]); if(nr==4) B3 = ei_pload(&blB[15*PacketSize]); - C0 = cj_pmadd(B0, A0, C0); - C4 = cj_pmadd(B0, A1, C4); - C1 = cj_pmadd(B1, A0, C1); - C5 = cj_pmadd(B1, A1, C5); - if(nr==4) C2 = cj_pmadd(B2, A0, C2); - if(nr==4) C6 = cj_pmadd(B2, A1, C6); - if(nr==4) C3 = cj_pmadd(B3, A0, C3); - if(nr==4) C7 = cj_pmadd(B3, A1, C7); + C0 = cj_pmadd(A0, B0, C0); + C4 = cj_pmadd(A1, B0, C4); + C1 = cj_pmadd(A0, B1, C1); + C5 = cj_pmadd(A1, B1, C5); + if(nr==4) C2 = cj_pmadd(A0, B2, C2); + if(nr==4) C6 = cj_pmadd(A1, B2, C6); + if(nr==4) C3 = cj_pmadd(A0, B3, C3); + if(nr==4) C7 = cj_pmadd(A1, B3, C7); blB += 4*nr*PacketSize; blA += 4*mr; @@ -318,16 +327,16 @@ static void ei_cache_friendly_product( A1 = ei_pload(&blA[1*PacketSize]); B0 = ei_pload(&blB[0*PacketSize]); B1 = ei_pload(&blB[1*PacketSize]); - C0 = cj_pmadd(B0, A0, C0); + C0 = cj_pmadd(A0, B0, C0); if(nr==4) B2 = ei_pload(&blB[2*PacketSize]); - C4 = cj_pmadd(B0, A1, C4); + C4 = cj_pmadd(A1, B0, C4); if(nr==4) B3 = ei_pload(&blB[3*PacketSize]); - C1 = cj_pmadd(B1, A0, C1); - C5 = cj_pmadd(B1, A1, C5); - if(nr==4) C2 = cj_pmadd(B2, A0, C2); - if(nr==4) C6 = cj_pmadd(B2, A1, C6); - if(nr==4) C3 = cj_pmadd(B3, A0, C3); - if(nr==4) C7 = cj_pmadd(B3, A1, C7); + C1 = cj_pmadd(A0, B1, C1); + C5 = cj_pmadd(A1, B1, C5); + if(nr==4) C2 = cj_pmadd(A0, B2, C2); + if(nr==4) C6 = cj_pmadd(A1, B2, C6); + if(nr==4) C3 = cj_pmadd(A0, B3, C3); + if(nr==4) C7 = cj_pmadd(A1, B3, C7); blB += nr*PacketSize; blA += mr; @@ -359,12 +368,12 @@ static void ei_cache_friendly_product( A0 = blA[k]; B0 = blB[0*PacketSize]; B1 = blB[1*PacketSize]; - C0 += B0 * A0; + C0 = cj_pmadd(A0, B0, C0); if(nr==4) B2 = blB[2*PacketSize]; if(nr==4) B3 = blB[3*PacketSize]; - C1 += B1 * A0; - if(nr==4) C2 += B2 * A0; - if(nr==4) C3 += B3 * A0; + C1 = cj_pmadd(A0, B1, C1); + if(nr==4) C2 = cj_pmadd(A0, B2, C2); + if(nr==4) C3 = cj_pmadd(A0, B3, C3); blB += nr*PacketSize; } @@ -382,10 +391,10 @@ static void ei_cache_friendly_product( Scalar c0 = Scalar(0); if (lhsRowMajor) for(int k=0; k