diff options
author | Gael Guennebaud <g.gael@free.fr> | 2009-07-07 21:30:20 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2009-07-07 21:30:20 +0200 |
commit | 13b2dafb5033a9de83c3dbd038b06c45845aeac1 (patch) | |
tree | d8d3d2905eb5a207635dbdbfe6111da943fdb3cc /Eigen | |
parent | 5ed6ce90d3d626e86127961f0845570223ac9c0b (diff) |
Conjugate expressions are now properly detected by Product,
which yields a significant speedup (about 3x) in expressions such as
a.adjoint() * b when the scalar type is complex.
Diffstat (limited to 'Eigen')
-rw-r--r-- | Eigen/src/Core/CwiseUnaryOp.h | 3 | ||||
-rw-r--r-- | Eigen/src/Core/Product.h | 49 | ||||
-rw-r--r-- | Eigen/src/Core/products/GeneralMatrixMatrix.h | 135 |
3 files changed, 107 insertions, 80 deletions
diff --git a/Eigen/src/Core/CwiseUnaryOp.h b/Eigen/src/Core/CwiseUnaryOp.h index 0095a1572..3ffb24833 100644 --- a/Eigen/src/Core/CwiseUnaryOp.h +++ b/Eigen/src/Core/CwiseUnaryOp.h @@ -96,7 +96,8 @@ class CwiseUnaryOp : ei_no_assignment_operator, const UnaryOp& _functor() const { return m_functor; } /** \internal used for introspection */ - const typename MatrixType::Nested& _expression() const { return m_matrix; } + const typename ei_cleantype<typename MatrixType::Nested>::type& + _expression() const { return m_matrix; } protected: const typename MatrixType::Nested m_matrix; diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h index 6849d90e3..a645ab6de 100644 --- a/Eigen/src/Core/Product.h +++ b/Eigen/src/Core/Product.h @@ -65,12 +65,11 @@ struct ProductReturnType template<typename Lhs, typename Rhs> struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct> { - typedef typename ei_nested<Lhs,Rhs::ColsAtCompileTime>::type LhsNested; - - typedef typename ei_nested<Rhs,Lhs::RowsAtCompileTime, + typedef typename ei_nested<Lhs,1>::type LhsNested; + typedef typename ei_nested<Rhs,1, typename ei_plain_matrix_type_column_major<Rhs>::type >::type RhsNested; - + typedef Product<LhsNested, RhsNested, CacheFriendlyProduct> Type; }; @@ -95,14 +94,14 @@ template<typename Lhs, typename Rhs> struct ei_product_mode template<typename XprType> struct ei_product_factor_traits { typedef typename ei_traits<XprType>::Scalar Scalar; - typedef XprType RealXprType; + typedef XprType ActualXprType; enum { IsComplex = NumTraits<Scalar>::IsComplex, NeedToConjugate = false, HasScalarMultiple = false, Access = int(ei_traits<XprType>::Flags)&DirectAccessBit ? 
HasDirectAccess : NoDirectAccess }; - static inline const RealXprType& extract(const XprType& x) { return x; } + static inline const ActualXprType& extract(const XprType& x) { return x; } static inline Scalar extractSalarFactor(const XprType&) { return Scalar(1); } }; @@ -112,13 +111,13 @@ template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<Cw { typedef ei_product_factor_traits<NestedXpr> Base; typedef CwiseUnaryOp<ei_scalar_conjugate_op<Scalar>, NestedXpr> XprType; - typedef typename Base::RealXprType RealXprType; + typedef typename Base::ActualXprType ActualXprType; enum { IsComplex = NumTraits<Scalar>::IsComplex, NeedToConjugate = IsComplex }; - static inline const RealXprType& extract(const XprType& x) { return x._expression(); } + static inline const ActualXprType& extract(const XprType& x) { return x._expression(); } static inline Scalar extractSalarFactor(const XprType& x) { return Base::extractSalarFactor(x._expression()); } }; @@ -128,12 +127,12 @@ template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<Cw { typedef ei_product_factor_traits<NestedXpr> Base; typedef CwiseUnaryOp<ei_scalar_multiple_op<Scalar>, NestedXpr> XprType; - typedef typename Base::RealXprType RealXprType; + typedef typename Base::ActualXprType ActualXprType; enum { HasScalarMultiple = true }; - static inline const RealXprType& extract(const XprType& x) { return x._expression(); } - static inline Scalar extractSalarFactor(const XprType& x) { return x._functor().value; } + static inline const ActualXprType& extract(const XprType& x) { return x._expression(); } + static inline Scalar extractSalarFactor(const XprType& x) { return x._functor().m_other; } }; /** \class Product @@ -819,18 +818,34 @@ template<typename Lhs, typename Rhs, int ProductMode> template<typename DestDerived> inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived& res, Scalar alpha) const { - typedef typename ei_product_copy_lhs<_LhsNested>::type 
LhsCopy; + typedef ei_product_factor_traits<_LhsNested> LhsProductTraits; + typedef ei_product_factor_traits<_RhsNested> RhsProductTraits; + + typedef typename LhsProductTraits::ActualXprType ActualLhsType; + typedef typename RhsProductTraits::ActualXprType ActualRhsType; + + const ActualLhsType& actualLhs = LhsProductTraits::extract(m_lhs); + const ActualRhsType& actualRhs = RhsProductTraits::extract(m_rhs); + + Scalar actualAlpha = alpha * LhsProductTraits::extractSalarFactor(m_lhs) + * RhsProductTraits::extractSalarFactor(m_rhs); + + typedef typename ei_product_copy_lhs<ActualLhsType>::type LhsCopy; typedef typename ei_unref<LhsCopy>::type _LhsCopy; - typedef typename ei_product_copy_rhs<_RhsNested>::type RhsCopy; + typedef typename ei_product_copy_rhs<ActualRhsType>::type RhsCopy; typedef typename ei_unref<RhsCopy>::type _RhsCopy; - LhsCopy lhs(m_lhs); - RhsCopy rhs(m_rhs); - ei_cache_friendly_product<Scalar,false,false>( + LhsCopy lhs(actualLhs); + RhsCopy rhs(actualRhs); + ei_cache_friendly_product<Scalar, +// LhsProductTraits::NeedToConjugate,RhsProductTraits::NeedToConjugate> + ((int(Flags)&RowMajorBit) ? bool(RhsProductTraits::NeedToConjugate) : bool(LhsProductTraits::NeedToConjugate)), + ((int(Flags)&RowMajorBit) ? 
bool(LhsProductTraits::NeedToConjugate) : bool(RhsProductTraits::NeedToConjugate))> + ( rows(), cols(), lhs.cols(), _LhsCopy::Flags&RowMajorBit, (const Scalar*)&(lhs.const_cast_derived().coeffRef(0,0)), lhs.stride(), _RhsCopy::Flags&RowMajorBit, (const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(), Flags&RowMajorBit, (Scalar*)&(res.coeffRef(0,0)), res.stride(), - alpha + actualAlpha ); } diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index 4630e5040..db63eadf9 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -58,6 +58,9 @@ template<> struct ei_conj_pmadd<true,true> #ifndef EIGEN_EXTERN_INSTANTIATIONS +/** \warning you should never call this function directly, + * this is because the ConjugateLhs/ConjugateRhs have to + * be flipped is resRowMajor==true */ template<typename Scalar, bool ConjugateLhs, bool ConjugateRhs> static void ei_cache_friendly_product( int _rows, int _cols, int depth, @@ -76,6 +79,12 @@ static void ei_cache_friendly_product( if (resRowMajor) { +// return ei_cache_friendly_product<Scalar,ConjugateRhs,ConjugateLhs>(_cols,_rows,depth, +// !_rhsRowMajor, _rhs, _rhsStride, +// !_lhsRowMajor, _lhs, _lhsStride, +// false, res, resStride, +// alpha); + lhs = _rhs; rhs = _lhs; lhsStride = _rhsStride; @@ -252,59 +261,59 @@ static void ei_cache_friendly_product( A1 = ei_pload(&blA[1*PacketSize]); B0 = ei_pload(&blB[0*PacketSize]); B1 = ei_pload(&blB[1*PacketSize]); - C0 = cj_pmadd(B0, A0, C0); + C0 = cj_pmadd(A0, B0, C0); if(nr==4) B2 = ei_pload(&blB[2*PacketSize]); - C4 = cj_pmadd(B0, A1, C4); + C4 = cj_pmadd(A1, B0, C4); if(nr==4) B3 = ei_pload(&blB[3*PacketSize]); B0 = ei_pload(&blB[(nr==4 ? 4 : 2)*PacketSize]); - C1 = cj_pmadd(B1, A0, C1); - C5 = cj_pmadd(B1, A1, C5); + C1 = cj_pmadd(A0, B1, C1); + C5 = cj_pmadd(A1, B1, C5); B1 = ei_pload(&blB[(nr==4 ? 
5 : 3)*PacketSize]); - if(nr==4) C2 = cj_pmadd(B2, A0, C2); - if(nr==4) C6 = cj_pmadd(B2, A1, C6); + if(nr==4) C2 = cj_pmadd(A0, B2, C2); + if(nr==4) C6 = cj_pmadd(A1, B2, C6); if(nr==4) B2 = ei_pload(&blB[6*PacketSize]); - if(nr==4) C3 = cj_pmadd(B3, A0, C3); + if(nr==4) C3 = cj_pmadd(A0, B3, C3); A0 = ei_pload(&blA[2*PacketSize]); - if(nr==4) C7 = cj_pmadd(B3, A1, C7); + if(nr==4) C7 = cj_pmadd(A1, B3, C7); A1 = ei_pload(&blA[3*PacketSize]); if(nr==4) B3 = ei_pload(&blB[7*PacketSize]); - C0 = cj_pmadd(B0, A0, C0); - C4 = cj_pmadd(B0, A1, C4); + C0 = cj_pmadd(A0, B0, C0); + C4 = cj_pmadd(A1, B0, C4); B0 = ei_pload(&blB[(nr==4 ? 8 : 4)*PacketSize]); - C1 = cj_pmadd(B1, A0, C1); - C5 = cj_pmadd(B1, A1, C5); + C1 = cj_pmadd(A0, B1, C1); + C5 = cj_pmadd(A1, B1, C5); B1 = ei_pload(&blB[(nr==4 ? 9 : 5)*PacketSize]); - if(nr==4) C2 = cj_pmadd(B2, A0, C2); - if(nr==4) C6 = cj_pmadd(B2, A1, C6); + if(nr==4) C2 = cj_pmadd(A0, B2, C2); + if(nr==4) C6 = cj_pmadd(A1, B2, C6); if(nr==4) B2 = ei_pload(&blB[10*PacketSize]); - if(nr==4) C3 = cj_pmadd(B3, A0, C3); + if(nr==4) C3 = cj_pmadd(A0, B3, C3); A0 = ei_pload(&blA[4*PacketSize]); - if(nr==4) C7 = cj_pmadd(B3, A1, C7); + if(nr==4) C7 = cj_pmadd(A1, B3, C7); A1 = ei_pload(&blA[5*PacketSize]); if(nr==4) B3 = ei_pload(&blB[11*PacketSize]); - C0 = cj_pmadd(B0, A0, C0); - C4 = cj_pmadd(B0, A1, C4); + C0 = cj_pmadd(A0, B0, C0); + C4 = cj_pmadd(A1, B0, C4); B0 = ei_pload(&blB[(nr==4 ? 12 : 6)*PacketSize]); - C1 = cj_pmadd(B1, A0, C1); - C5 = cj_pmadd(B1, A1, C5); + C1 = cj_pmadd(A0, B1, C1); + C5 = cj_pmadd(A1, B1, C5); B1 = ei_pload(&blB[(nr==4 ? 
13 : 7)*PacketSize]); - if(nr==4) C2 = cj_pmadd(B2, A0, C2); - if(nr==4) C6 = cj_pmadd(B2, A1, C6); + if(nr==4) C2 = cj_pmadd(A0, B2, C2); + if(nr==4) C6 = cj_pmadd(A1, B2, C6); if(nr==4) B2 = ei_pload(&blB[14*PacketSize]); - if(nr==4) C3 = cj_pmadd(B3, A0, C3); + if(nr==4) C3 = cj_pmadd(A0, B3, C3); A0 = ei_pload(&blA[6*PacketSize]); - if(nr==4) C7 = cj_pmadd(B3, A1, C7); + if(nr==4) C7 = cj_pmadd(A1, B3, C7); A1 = ei_pload(&blA[7*PacketSize]); if(nr==4) B3 = ei_pload(&blB[15*PacketSize]); - C0 = cj_pmadd(B0, A0, C0); - C4 = cj_pmadd(B0, A1, C4); - C1 = cj_pmadd(B1, A0, C1); - C5 = cj_pmadd(B1, A1, C5); - if(nr==4) C2 = cj_pmadd(B2, A0, C2); - if(nr==4) C6 = cj_pmadd(B2, A1, C6); - if(nr==4) C3 = cj_pmadd(B3, A0, C3); - if(nr==4) C7 = cj_pmadd(B3, A1, C7); + C0 = cj_pmadd(A0, B0, C0); + C4 = cj_pmadd(A1, B0, C4); + C1 = cj_pmadd(A0, B1, C1); + C5 = cj_pmadd(A1, B1, C5); + if(nr==4) C2 = cj_pmadd(A0, B2, C2); + if(nr==4) C6 = cj_pmadd(A1, B2, C6); + if(nr==4) C3 = cj_pmadd(A0, B3, C3); + if(nr==4) C7 = cj_pmadd(A1, B3, C7); blB += 4*nr*PacketSize; blA += 4*mr; @@ -318,16 +327,16 @@ static void ei_cache_friendly_product( A1 = ei_pload(&blA[1*PacketSize]); B0 = ei_pload(&blB[0*PacketSize]); B1 = ei_pload(&blB[1*PacketSize]); - C0 = cj_pmadd(B0, A0, C0); + C0 = cj_pmadd(A0, B0, C0); if(nr==4) B2 = ei_pload(&blB[2*PacketSize]); - C4 = cj_pmadd(B0, A1, C4); + C4 = cj_pmadd(A1, B0, C4); if(nr==4) B3 = ei_pload(&blB[3*PacketSize]); - C1 = cj_pmadd(B1, A0, C1); - C5 = cj_pmadd(B1, A1, C5); - if(nr==4) C2 = cj_pmadd(B2, A0, C2); - if(nr==4) C6 = cj_pmadd(B2, A1, C6); - if(nr==4) C3 = cj_pmadd(B3, A0, C3); - if(nr==4) C7 = cj_pmadd(B3, A1, C7); + C1 = cj_pmadd(A0, B1, C1); + C5 = cj_pmadd(A1, B1, C5); + if(nr==4) C2 = cj_pmadd(A0, B2, C2); + if(nr==4) C6 = cj_pmadd(A1, B2, C6); + if(nr==4) C3 = cj_pmadd(A0, B3, C3); + if(nr==4) C7 = cj_pmadd(A1, B3, C7); blB += nr*PacketSize; blA += mr; @@ -359,12 +368,12 @@ static void ei_cache_friendly_product( A0 = blA[k]; B0 = 
blB[0*PacketSize]; B1 = blB[1*PacketSize]; - C0 += B0 * A0; + C0 = cj_pmadd(A0, B0, C0); if(nr==4) B2 = blB[2*PacketSize]; if(nr==4) B3 = blB[3*PacketSize]; - C1 += B1 * A0; - if(nr==4) C2 += B2 * A0; - if(nr==4) C3 += B3 * A0; + C1 = cj_pmadd(A0, B1, C1); + if(nr==4) C2 = cj_pmadd(A0, B2, C2); + if(nr==4) C3 = cj_pmadd(A0, B3, C3); blB += nr*PacketSize; } @@ -382,10 +391,10 @@ static void ei_cache_friendly_product( Scalar c0 = Scalar(0); if (lhsRowMajor) for(int k=0; k<actual_kc; k++) - c0 += lhs[(k2+k)+(i2+i)*lhsStride] * rhs[j2*rhsStride + k2 + k]; + c0 = cj_pmadd(lhs[(k2+k)+(i2+i)*lhsStride], rhs[j2*rhsStride + k2 + k], c0); else for(int k=0; k<actual_kc; k++) - c0 += lhs[(k2+k)*lhsStride + i2+i] * rhs[j2*rhsStride + k2 + k]; + c0 = cj_pmadd(lhs[(k2+k)*lhsStride + i2+i], rhs[j2*rhsStride + k2 + k], c0); res[(j2)*resStride + i2+i] += alpha * c0; } } @@ -395,6 +404,8 @@ static void ei_cache_friendly_product( ei_aligned_stack_delete(Scalar, blockA, kc*mc); ei_aligned_stack_delete(Scalar, blockB, kc*cols*PacketSize); + + #else // alternate product from cylmor enum { @@ -482,39 +493,39 @@ static void ei_cache_friendly_product( L0 = ei_pload(&lb[1*PacketSize]); R1 = ei_pload(&lb[2*PacketSize]); L1 = ei_pload(&lb[3*PacketSize]); - T0 = cj_pmadd(R0, A0, T0); - T1 = cj_pmadd(L0, A0, T1); + T0 = cj_pmadd(A0, R0, T0); + T1 = cj_pmadd(A0, L0, T1); R0 = ei_pload(&lb[4*PacketSize]); L0 = ei_pload(&lb[5*PacketSize]); - T0 = cj_pmadd(R1, A1, T0); - T1 = cj_pmadd(L1, A1, T1); + T0 = cj_pmadd(A1, R1, T0); + T1 = cj_pmadd(A1, L1, T1); R1 = ei_pload(&lb[6*PacketSize]); L1 = ei_pload(&lb[7*PacketSize]); - T0 = cj_pmadd(R0, A2, T0); - T1 = cj_pmadd(L0, A2, T1); + T0 = cj_pmadd(A2, R0, T0); + T1 = cj_pmadd(A2, L0, T1); if(MaxBlockRows==8) { R0 = ei_pload(&lb[8*PacketSize]); L0 = ei_pload(&lb[9*PacketSize]); } - T0 = cj_pmadd(R1, A3, T0); - T1 = cj_pmadd(L1, A3, T1); + T0 = cj_pmadd(A3, R1, T0); + T1 = cj_pmadd(A3, L1, T1); if(MaxBlockRows==8) { R1 = ei_pload(&lb[10*PacketSize]); L1 = 
ei_pload(&lb[11*PacketSize]); - T0 = cj_pmadd(R0, A4, T0); - T1 = cj_pmadd(L0, A4, T1); + T0 = cj_pmadd(A4, R0, T0); + T1 = cj_pmadd(A4, L0, T1); R0 = ei_pload(&lb[12*PacketSize]); L0 = ei_pload(&lb[13*PacketSize]); - T0 = cj_pmadd(R1, A5, T0); - T1 = cj_pmadd(L1, A5, T1); + T0 = cj_pmadd(A5, R1, T0); + T1 = cj_pmadd(A5, L1, T1); R1 = ei_pload(&lb[14*PacketSize]); L1 = ei_pload(&lb[15*PacketSize]); - T0 = cj_pmadd(R0, A6, T0); - T1 = cj_pmadd(L0, A6, T1); - T0 = cj_pmadd(R1, A7, T0); - T1 = cj_pmadd(L1, A7, T1); + T0 = cj_pmadd(A6, R0, T0); + T1 = cj_pmadd(A6, L0, T1); + T0 = cj_pmadd(A7, R1, T0); + T1 = cj_pmadd(A7, L1, T1); } lb += MaxBlockRows*2*PacketSize; |