diff options
author | Gael Guennebaud <g.gael@free.fr> | 2009-07-08 18:24:37 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2009-07-08 18:24:37 +0200 |
commit | 96e7d9f8969395db702775eaa0907b4aa941b2ba (patch) | |
tree | cf0789828bcf49b3ec4f4bf1ff28af1b2c30d0f1 /Eigen/src/Core | |
parent | 13b2dafb5033a9de83c3dbd038b06c45845aeac1 (diff) |
ok now all the complex mat-mat and mat-vec products involving conjugate,
adjoint, -, and scalar multiple seems to be well handled. It only remains
the simpler case: C = alpha*(A*B) ... for the next commit
Diffstat (limited to 'Eigen/src/Core')
-rw-r--r-- | Eigen/src/Core/Product.h | 185 | ||||
-rw-r--r-- | Eigen/src/Core/products/GeneralMatrixMatrix.h | 170 | ||||
-rw-r--r-- | Eigen/src/Core/products/GeneralMatrixVector.h | 101 |
3 files changed, 282 insertions, 174 deletions
diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h index a645ab6de..d63a7aa95 100644 --- a/Eigen/src/Core/Product.h +++ b/Eigen/src/Core/Product.h @@ -73,24 +73,9 @@ struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct> typedef Product<LhsNested, RhsNested, CacheFriendlyProduct> Type; }; -/* Helper class to determine the type of the product, can be either: - * - NormalProduct - * - CacheFriendlyProduct - */ -template<typename Lhs, typename Rhs> struct ei_product_mode -{ - enum{ - - value = Lhs::MaxColsAtCompileTime == Dynamic - && ( Lhs::MaxRowsAtCompileTime == Dynamic - || Rhs::MaxColsAtCompileTime == Dynamic ) - && (!(Rhs::IsVectorAtCompileTime && (Lhs::Flags&RowMajorBit) && (!(Lhs::Flags&DirectAccessBit)))) - && (!(Lhs::IsVectorAtCompileTime && (!(Rhs::Flags&RowMajorBit)) && (!(Rhs::Flags&DirectAccessBit)))) - && (ei_is_same_type<typename Lhs::Scalar, typename Rhs::Scalar>::ret) - ? CacheFriendlyProduct - : NormalProduct }; -}; - +/* Helper class to analyze the factors of a Product expression. + * In particular it allows to pop out operator-, scalar multiples, + * and conjugate */ template<typename XprType> struct ei_product_factor_traits { typedef typename ei_traits<XprType>::Scalar Scalar; @@ -98,11 +83,10 @@ template<typename XprType> struct ei_product_factor_traits enum { IsComplex = NumTraits<Scalar>::IsComplex, NeedToConjugate = false, - HasScalarMultiple = false, - Access = int(ei_traits<XprType>::Flags)&DirectAccessBit ? HasDirectAccess : NoDirectAccess + ActualAccess = int(ei_traits<XprType>::Flags)&DirectAccessBit ? HasDirectAccess : NoDirectAccess }; static inline const ActualXprType& extract(const XprType& x) { return x; } - static inline Scalar extractSalarFactor(const XprType&) { return Scalar(1); } + static inline Scalar extractScalarFactor(const XprType&) { return Scalar(1); } }; // pop conjugate @@ -117,8 +101,8 @@ template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<Cw IsComplex = NumTraits<Scalar>::IsComplex, NeedToConjugate = IsComplex }; - static inline const ActualXprType& extract(const XprType& x) { return x._expression(); } - static inline Scalar extractSalarFactor(const XprType& x) { return Base::extractSalarFactor(x._expression()); } + static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); } + static inline Scalar extractScalarFactor(const XprType& x) { return ei_conj(Base::extractScalarFactor(x._expression())); } }; // pop scalar multiple @@ -128,11 +112,41 @@ template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<Cw typedef ei_product_factor_traits<NestedXpr> Base; typedef CwiseUnaryOp<ei_scalar_multiple_op<Scalar>, NestedXpr> XprType; typedef typename Base::ActualXprType ActualXprType; - enum { - HasScalarMultiple = true - }; - static inline const ActualXprType& extract(const XprType& x) { return x._expression(); } - static inline Scalar extractSalarFactor(const XprType& x) { return x._functor().m_other; } + static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); } + static inline Scalar extractScalarFactor(const XprType& x) + { return x._functor().m_other * Base::extractScalarFactor(x._expression()); } +}; + +// pop opposite +template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<CwiseUnaryOp<ei_scalar_opposite_op<Scalar>, NestedXpr> > + : ei_product_factor_traits<NestedXpr> +{ + typedef ei_product_factor_traits<NestedXpr> Base; + typedef CwiseUnaryOp<ei_scalar_opposite_op<Scalar>, NestedXpr> XprType; + typedef typename Base::ActualXprType ActualXprType; + static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); } + static inline Scalar extractScalarFactor(const XprType& x) + { return - Base::extractScalarFactor(x._expression()); } +}; + +/* Helper class to determine the type of the product, can be either: + * - NormalProduct + * - CacheFriendlyProduct + */ +template<typename Lhs, typename Rhs> struct ei_product_mode +{ + typedef typename ei_product_factor_traits<Lhs>::ActualXprType ActualLhs; + typedef typename ei_product_factor_traits<Rhs>::ActualXprType ActualRhs; + enum{ + + value = Lhs::MaxColsAtCompileTime == Dynamic + && ( Lhs::MaxRowsAtCompileTime == Dynamic + || Rhs::MaxColsAtCompileTime == Dynamic ) + && (!(Rhs::IsVectorAtCompileTime && (Lhs::Flags&RowMajorBit) && (!(ActualLhs::Flags&DirectAccessBit)))) + && (!(Lhs::IsVectorAtCompileTime && (!(Rhs::Flags&RowMajorBit)) && (!(ActualRhs::Flags&DirectAccessBit)))) + && (ei_is_same_type<typename Lhs::Scalar, typename Rhs::Scalar>::ret) + ? CacheFriendlyProduct + : NormalProduct }; }; /** \class Product @@ -552,11 +566,11 @@ void ei_cache_friendly_product( bool resRowMajor, Scalar* res, int resStride, Scalar alpha); -template<typename Scalar, typename RhsType> +template<bool ConjugateLhs, bool ConjugateRhs, typename Scalar, typename RhsType> static void ei_cache_friendly_product_colmajor_times_vector( int size, const Scalar* lhs, int lhsStride, const RhsType& rhs, Scalar* res, Scalar alpha); -template<typename Scalar, typename ResType> +template<bool ConjugateLhs, bool ConjugateRhs, typename Scalar, typename ResType> static void ei_cache_friendly_product_rowmajor_times_vector( const Scalar* lhs, int lhsStride, const Scalar* rhs, int rhsSize, ResType& res, Scalar alpha); @@ -572,10 +586,10 @@ static void ei_cache_friendly_product_rowmajor_times_vector( template<typename ProductType, int LhsRows = ei_traits<ProductType>::RowsAtCompileTime, int LhsOrder = int(ei_traits<ProductType>::LhsFlags)&RowMajorBit ? RowMajor : ColMajor, - int LhsHasDirectAccess = int(ei_traits<ProductType>::LhsFlags)&DirectAccessBit? HasDirectAccess : NoDirectAccess, + int LhsHasDirectAccess = ei_product_factor_traits<typename ei_traits<ProductType>::_LhsNested>::ActualAccess, int RhsCols = ei_traits<ProductType>::ColsAtCompileTime, int RhsOrder = int(ei_traits<ProductType>::RhsFlags)&RowMajorBit ? RowMajor : ColMajor, - int RhsHasDirectAccess = int(ei_traits<ProductType>::RhsFlags)&DirectAccessBit? HasDirectAccess : NoDirectAccess> + int RhsHasDirectAccess = ei_product_factor_traits<typename ei_traits<ProductType>::_RhsNested>::ActualAccess> struct ei_cache_friendly_product_selector { template<typename DestDerived> @@ -592,7 +606,6 @@ struct ei_cache_friendly_product_selector<ProductType,LhsRows,ColMajor,NoDirectA template<typename DestDerived> inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha) { - // FIXME is it really used ? ei_assert(alpha==typename ProductType::Scalar(1)); const int size = product.rhs().rows(); for (int k=0; k<size; ++k) @@ -606,10 +619,21 @@ template<typename ProductType, int LhsRows, int RhsOrder, int RhsAccess> struct ei_cache_friendly_product_selector<ProductType,LhsRows,ColMajor,HasDirectAccess,1,RhsOrder,RhsAccess> { typedef typename ProductType::Scalar Scalar; + typedef ei_product_factor_traits<typename ei_traits<ProductType>::_LhsNested> LhsProductTraits; + typedef ei_product_factor_traits<typename ei_traits<ProductType>::_RhsNested> RhsProductTraits; + typedef typename LhsProductTraits::ActualXprType ActualLhsType; + typedef typename RhsProductTraits::ActualXprType ActualRhsType; + template<typename DestDerived> inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha) { + const ActualLhsType& actualLhs = LhsProductTraits::extract(product.lhs()); + const ActualRhsType& actualRhs = RhsProductTraits::extract(product.rhs()); + + Scalar actualAlpha = alpha * LhsProductTraits::extractScalarFactor(product.lhs()) + * RhsProductTraits::extractScalarFactor(product.rhs()); + enum { EvalToRes = (ei_packet_traits<Scalar>::size==1) ||((DestDerived::Flags&ActualPacketAccessBit) && (!(DestDerived::Flags & RowMajorBit))) }; @@ -621,9 +645,12 @@ struct ei_cache_friendly_product_selector<ProductType,LhsRows,ColMajor,HasDirect _res = ei_aligned_stack_new(Scalar,res.size()); Map<Matrix<Scalar,DestDerived::RowsAtCompileTime,1> >(_res, res.size()) = res; } - ei_cache_friendly_product_colmajor_times_vector(res.size(), - &product.lhs().const_cast_derived().coeffRef(0,0), product.lhs().stride(), - product.rhs(), _res, alpha); + + ei_cache_friendly_product_colmajor_times_vector + <LhsProductTraits::NeedToConjugate,RhsProductTraits::NeedToConjugate>( + res.size(), + &actualLhs.const_cast_derived().coeffRef(0,0), actualLhs.stride(), + actualRhs, _res, actualAlpha); if (!EvalToRes) { @@ -653,10 +680,21 @@ template<typename ProductType, int LhsOrder, int LhsAccess, int RhsCols> struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCols,RowMajor,HasDirectAccess> { typedef typename ProductType::Scalar Scalar; + typedef ei_product_factor_traits<typename ei_traits<ProductType>::_LhsNested> LhsProductTraits; + typedef ei_product_factor_traits<typename ei_traits<ProductType>::_RhsNested> RhsProductTraits; + + typedef typename LhsProductTraits::ActualXprType ActualLhsType; + typedef typename RhsProductTraits::ActualXprType ActualRhsType; template<typename DestDerived> inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha) { + const ActualLhsType& actualLhs = LhsProductTraits::extract(product.lhs()); + const ActualRhsType& actualRhs = RhsProductTraits::extract(product.rhs()); + + Scalar actualAlpha = alpha * LhsProductTraits::extractScalarFactor(product.lhs()) + * RhsProductTraits::extractScalarFactor(product.rhs()); + enum { EvalToRes = (ei_packet_traits<Scalar>::size==1) ||((DestDerived::Flags & ActualPacketAccessBit) && (DestDerived::Flags & RowMajorBit)) }; @@ -668,9 +706,11 @@ struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCo _res = ei_aligned_stack_new(Scalar, res.size()); Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1> >(_res, res.size()) = res; } - ei_cache_friendly_product_colmajor_times_vector(res.size(), - &product.rhs().const_cast_derived().coeffRef(0,0), product.rhs().stride(), - product.lhs().transpose(), _res, alpha); + + ei_cache_friendly_product_colmajor_times_vector + <RhsProductTraits::NeedToConjugate,LhsProductTraits::NeedToConjugate>(res.size(), + &actualRhs.const_cast_derived().coeffRef(0,0), actualRhs.stride(), + actualLhs.transpose(), _res, actualAlpha); if (!EvalToRes) { @@ -685,24 +725,39 @@ template<typename ProductType, int LhsRows, int RhsOrder, int RhsAccess> struct ei_cache_friendly_product_selector<ProductType,LhsRows,RowMajor,HasDirectAccess,1,RhsOrder,RhsAccess> { typedef typename ProductType::Scalar Scalar; - typedef typename ei_traits<ProductType>::_RhsNested Rhs; + + typedef ei_product_factor_traits<typename ei_traits<ProductType>::_LhsNested> LhsProductTraits; + typedef ei_product_factor_traits<typename ei_traits<ProductType>::_RhsNested> RhsProductTraits; + + typedef typename LhsProductTraits::ActualXprType ActualLhsType; + typedef typename RhsProductTraits::ActualXprType ActualRhsType; + enum { - UseRhsDirectly = ((ei_packet_traits<Scalar>::size==1) || (Rhs::Flags&ActualPacketAccessBit)) - && (!(Rhs::Flags & RowMajorBit)) }; + UseRhsDirectly = ((ei_packet_traits<Scalar>::size==1) || (ActualRhsType::Flags&ActualPacketAccessBit)) + && (!(ActualRhsType::Flags & RowMajorBit)) }; template<typename DestDerived> inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha) { + const ActualLhsType& actualLhs = LhsProductTraits::extract(product.lhs()); + const ActualRhsType& actualRhs = RhsProductTraits::extract(product.rhs()); + + Scalar actualAlpha = alpha * LhsProductTraits::extractScalarFactor(product.lhs()) + * RhsProductTraits::extractScalarFactor(product.rhs()); + Scalar* EIGEN_RESTRICT _rhs; if (UseRhsDirectly) - _rhs = &product.rhs().const_cast_derived().coeffRef(0); + _rhs = &actualRhs.const_cast_derived().coeffRef(0); else { - _rhs = ei_aligned_stack_new(Scalar, product.rhs().size()); - Map<Matrix<Scalar,Rhs::SizeAtCompileTime,1> >(_rhs, product.rhs().size()) = product.rhs(); + _rhs = ei_aligned_stack_new(Scalar, actualRhs.size()); + Map<Matrix<Scalar,ActualRhsType::SizeAtCompileTime,1> >(_rhs, actualRhs.size()) = actualRhs; } - ei_cache_friendly_product_rowmajor_times_vector(&product.lhs().const_cast_derived().coeffRef(0,0), product.lhs().stride(), - _rhs, product.rhs().size(), res, alpha); + + ei_cache_friendly_product_rowmajor_times_vector + <LhsProductTraits::NeedToConjugate,RhsProductTraits::NeedToConjugate>( + &actualLhs.const_cast_derived().coeffRef(0,0), actualLhs.stride(), + _rhs, product.rhs().size(), res, actualAlpha); if (!UseRhsDirectly) ei_aligned_stack_delete(Scalar, _rhs, product.rhs().size()); } @@ -713,24 +768,39 @@ template<typename ProductType, int LhsOrder, int LhsAccess, int RhsCols> struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCols,ColMajor,HasDirectAccess> { typedef typename ProductType::Scalar Scalar; - typedef typename ei_traits<ProductType>::_LhsNested Lhs; + + typedef ei_product_factor_traits<typename ei_traits<ProductType>::_LhsNested> LhsProductTraits; + typedef ei_product_factor_traits<typename ei_traits<ProductType>::_RhsNested> RhsProductTraits; + + typedef typename LhsProductTraits::ActualXprType ActualLhsType; + typedef typename RhsProductTraits::ActualXprType ActualRhsType; + enum { - UseLhsDirectly = ((ei_packet_traits<Scalar>::size==1) || (Lhs::Flags&ActualPacketAccessBit)) - && (Lhs::Flags & RowMajorBit) }; + UseLhsDirectly = ((ei_packet_traits<Scalar>::size==1) || (ActualLhsType::Flags&ActualPacketAccessBit)) + && (ActualLhsType::Flags & RowMajorBit) }; template<typename DestDerived> inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha) { + const ActualLhsType& actualLhs = LhsProductTraits::extract(product.lhs()); + const ActualRhsType& actualRhs = RhsProductTraits::extract(product.rhs()); + + Scalar actualAlpha = alpha * LhsProductTraits::extractScalarFactor(product.lhs()) + * RhsProductTraits::extractScalarFactor(product.rhs()); + Scalar* EIGEN_RESTRICT _lhs; if (UseLhsDirectly) - _lhs = &product.lhs().const_cast_derived().coeffRef(0); + _lhs = &actualLhs.const_cast_derived().coeffRef(0); else { - _lhs = ei_aligned_stack_new(Scalar, product.lhs().size()); - Map<Matrix<Scalar,Lhs::SizeAtCompileTime,1> >(_lhs, product.lhs().size()) = product.lhs(); + _lhs = ei_aligned_stack_new(Scalar, actualLhs.size()); + Map<Matrix<Scalar,ActualLhsType::SizeAtCompileTime,1> >(_lhs, actualLhs.size()) = actualLhs; } - ei_cache_friendly_product_rowmajor_times_vector(&product.rhs().const_cast_derived().coeffRef(0,0), product.rhs().stride(), - _lhs, product.lhs().size(), res, alpha); + + ei_cache_friendly_product_rowmajor_times_vector + <RhsProductTraits::NeedToConjugate, LhsProductTraits::NeedToConjugate>( + &actualRhs.const_cast_derived().coeffRef(0,0), actualRhs.stride(), + _lhs, product.lhs().size(), res, actualAlpha); if(!UseLhsDirectly) ei_aligned_stack_delete(Scalar, _lhs, product.lhs().size()); } @@ -827,8 +897,8 @@ inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived& const ActualLhsType& actualLhs = LhsProductTraits::extract(m_lhs); const ActualRhsType& actualRhs = RhsProductTraits::extract(m_rhs); - Scalar actualAlpha = alpha * LhsProductTraits::extractSalarFactor(m_lhs) - * RhsProductTraits::extractSalarFactor(m_rhs); + Scalar actualAlpha = alpha * LhsProductTraits::extractScalarFactor(m_lhs) + * RhsProductTraits::extractScalarFactor(m_rhs); typedef typename ei_product_copy_lhs<ActualLhsType>::type LhsCopy; typedef typename ei_unref<LhsCopy>::type _LhsCopy; @@ -837,7 +907,6 @@ inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived& LhsCopy lhs(actualLhs); RhsCopy rhs(actualRhs); ei_cache_friendly_product<Scalar, -// LhsProductTraits::NeedToConjugate,RhsProductTraits::NeedToConjugate> ((int(Flags)&RowMajorBit) ? bool(RhsProductTraits::NeedToConjugate) : bool(LhsProductTraits::NeedToConjugate)), ((int(Flags)&RowMajorBit) ? bool(LhsProductTraits::NeedToConjugate) : bool(RhsProductTraits::NeedToConjugate))> ( diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index db63eadf9..afd97b340 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -30,30 +30,46 @@ struct ei_L2_block_traits { enum {width = 8 * ei_meta_sqrt<L2MemorySize/(64*sizeof(Scalar))>::ret }; }; -template<bool ConjLhs, bool ConjRhs> struct ei_conj_pmadd; +template<bool ConjLhs, bool ConjRhs> struct ei_conj_helper; -template<> struct ei_conj_pmadd<false,false> +template<> struct ei_conj_helper<false,false> { template<typename T> - EIGEN_STRONG_INLINE T operator()(const T& x, const T& y, T& c) const { return ei_pmadd(x,y,c); } + EIGEN_STRONG_INLINE T pmadd(const T& x, const T& y, const T& c) const { return ei_pmadd(x,y,c); } + template<typename T> + EIGEN_STRONG_INLINE T pmul(const T& x, const T& y) const { return ei_pmul(x,y); } }; -template<> struct ei_conj_pmadd<false,true> +template<> struct ei_conj_helper<false,true> { - template<typename T> std::complex<T> operator()(const std::complex<T>& x, const std::complex<T>& y, std::complex<T>& c) const - { return c + std::complex<T>(ei_real(x)*ei_real(y) + ei_imag(x)*ei_imag(y), ei_imag(x)*ei_real(y) - ei_real(x)*ei_imag(y)); } + template<typename T> std::complex<T> + pmadd(const std::complex<T>& x, const std::complex<T>& y, const std::complex<T>& c) const + { return c + pmul(x,y); } + + template<typename T> std::complex<T> pmul(const std::complex<T>& x, const std::complex<T>& y) const + //{ return std::complex<T>(ei_real(x)*ei_real(y) + ei_imag(x)*ei_imag(y), ei_imag(x)*ei_real(y) - ei_real(x)*ei_imag(y)); } + { return x * ei_conj(y); } }; -template<> struct ei_conj_pmadd<true,false> +template<> struct ei_conj_helper<true,false> { - template<typename T> std::complex<T> operator()(const std::complex<T>& x, const std::complex<T>& y, std::complex<T>& c) const - { return c + std::complex<T>(ei_real(x)*ei_real(y) + ei_imag(x)*ei_imag(y), ei_real(x)*ei_imag(y) - ei_imag(x)*ei_real(y)); } + template<typename T> std::complex<T> + pmadd(const std::complex<T>& x, const std::complex<T>& y, const std::complex<T>& c) const + { return c + pmul(x,y); } + + template<typename T> std::complex<T> pmul(const std::complex<T>& x, const std::complex<T>& y) const + { return std::complex<T>(ei_real(x)*ei_real(y) + ei_imag(x)*ei_imag(y), ei_real(x)*ei_imag(y) - ei_imag(x)*ei_real(y)); } }; -template<> struct ei_conj_pmadd<true,true> +template<> struct ei_conj_helper<true,true> { - template<typename T> std::complex<T> operator()(const std::complex<T>& x, const std::complex<T>& y, std::complex<T>& c) const - { return c + std::complex<T>(ei_real(x)*ei_real(y) - ei_imag(x)*ei_imag(y), - ei_real(x)*ei_imag(y) - ei_imag(x)*ei_real(y)); } + template<typename T> std::complex<T> + pmadd(const std::complex<T>& x, const std::complex<T>& y, const std::complex<T>& c) const + { return c + pmul(x,y); } + + template<typename T> std::complex<T> pmul(const std::complex<T>& x, const std::complex<T>& y) const +// { return std::complex<T>(ei_real(x)*ei_real(y) - ei_imag(x)*ei_imag(y), - ei_real(x)*ei_imag(y) - ei_imag(x)*ei_real(y)); } + { return ei_conj(x) * ei_conj(y); } }; #ifndef EIGEN_EXTERN_INSTANTIATIONS @@ -74,7 +90,9 @@ static void ei_cache_friendly_product( int lhsStride, rhsStride, rows, cols; bool lhsRowMajor; - ei_conj_pmadd<ConjugateLhs,ConjugateRhs> cj_pmadd; + ei_conj_helper<ConjugateLhs,ConjugateRhs> cj; + if (ConjugateRhs) + alpha = ei_conj(alpha); bool hasAlpha = alpha != Scalar(1); if (resRowMajor) @@ -261,59 +279,59 @@ static void ei_cache_friendly_product( A1 = ei_pload(&blA[1*PacketSize]); B0 = ei_pload(&blB[0*PacketSize]); B1 = ei_pload(&blB[1*PacketSize]); - C0 = cj_pmadd(A0, B0, C0); + C0 = cj.pmadd(A0, B0, C0); if(nr==4) B2 = ei_pload(&blB[2*PacketSize]); - C4 = cj_pmadd(A1, B0, C4); + C4 = cj.pmadd(A1, B0, C4); if(nr==4) B3 = ei_pload(&blB[3*PacketSize]); B0 = ei_pload(&blB[(nr==4 ? 4 : 2)*PacketSize]); - C1 = cj_pmadd(A0, B1, C1); - C5 = cj_pmadd(A1, B1, C5); + C1 = cj.pmadd(A0, B1, C1); + C5 = cj.pmadd(A1, B1, C5); B1 = ei_pload(&blB[(nr==4 ? 5 : 3)*PacketSize]); - if(nr==4) C2 = cj_pmadd(A0, B2, C2); - if(nr==4) C6 = cj_pmadd(A1, B2, C6); + if(nr==4) C2 = cj.pmadd(A0, B2, C2); + if(nr==4) C6 = cj.pmadd(A1, B2, C6); if(nr==4) B2 = ei_pload(&blB[6*PacketSize]); - if(nr==4) C3 = cj_pmadd(A0, B3, C3); + if(nr==4) C3 = cj.pmadd(A0, B3, C3); A0 = ei_pload(&blA[2*PacketSize]); - if(nr==4) C7 = cj_pmadd(A1, B3, C7); + if(nr==4) C7 = cj.pmadd(A1, B3, C7); A1 = ei_pload(&blA[3*PacketSize]); if(nr==4) B3 = ei_pload(&blB[7*PacketSize]); - C0 = cj_pmadd(A0, B0, C0); - C4 = cj_pmadd(A1, B0, C4); + C0 = cj.pmadd(A0, B0, C0); + C4 = cj.pmadd(A1, B0, C4); B0 = ei_pload(&blB[(nr==4 ? 8 : 4)*PacketSize]); - C1 = cj_pmadd(A0, B1, C1); - C5 = cj_pmadd(A1, B1, C5); + C1 = cj.pmadd(A0, B1, C1); + C5 = cj.pmadd(A1, B1, C5); B1 = ei_pload(&blB[(nr==4 ? 9 : 5)*PacketSize]); - if(nr==4) C2 = cj_pmadd(A0, B2, C2); - if(nr==4) C6 = cj_pmadd(A1, B2, C6); + if(nr==4) C2 = cj.pmadd(A0, B2, C2); + if(nr==4) C6 = cj.pmadd(A1, B2, C6); if(nr==4) B2 = ei_pload(&blB[10*PacketSize]); - if(nr==4) C3 = cj_pmadd(A0, B3, C3); + if(nr==4) C3 = cj.pmadd(A0, B3, C3); A0 = ei_pload(&blA[4*PacketSize]); - if(nr==4) C7 = cj_pmadd(A1, B3, C7); + if(nr==4) C7 = cj.pmadd(A1, B3, C7); A1 = ei_pload(&blA[5*PacketSize]); if(nr==4) B3 = ei_pload(&blB[11*PacketSize]); - C0 = cj_pmadd(A0, B0, C0); - C4 = cj_pmadd(A1, B0, C4); + C0 = cj.pmadd(A0, B0, C0); + C4 = cj.pmadd(A1, B0, C4); B0 = ei_pload(&blB[(nr==4 ? 12 : 6)*PacketSize]); - C1 = cj_pmadd(A0, B1, C1); - C5 = cj_pmadd(A1, B1, C5); + C1 = cj.pmadd(A0, B1, C1); + C5 = cj.pmadd(A1, B1, C5); B1 = ei_pload(&blB[(nr==4 ? 13 : 7)*PacketSize]); - if(nr==4) C2 = cj_pmadd(A0, B2, C2); - if(nr==4) C6 = cj_pmadd(A1, B2, C6); + if(nr==4) C2 = cj.pmadd(A0, B2, C2); + if(nr==4) C6 = cj.pmadd(A1, B2, C6); if(nr==4) B2 = ei_pload(&blB[14*PacketSize]); - if(nr==4) C3 = cj_pmadd(A0, B3, C3); + if(nr==4) C3 = cj.pmadd(A0, B3, C3); A0 = ei_pload(&blA[6*PacketSize]); - if(nr==4) C7 = cj_pmadd(A1, B3, C7); + if(nr==4) C7 = cj.pmadd(A1, B3, C7); A1 = ei_pload(&blA[7*PacketSize]); if(nr==4) B3 = ei_pload(&blB[15*PacketSize]); - C0 = cj_pmadd(A0, B0, C0); - C4 = cj_pmadd(A1, B0, C4); - C1 = cj_pmadd(A0, B1, C1); - C5 = cj_pmadd(A1, B1, C5); - if(nr==4) C2 = cj_pmadd(A0, B2, C2); - if(nr==4) C6 = cj_pmadd(A1, B2, C6); - if(nr==4) C3 = cj_pmadd(A0, B3, C3); - if(nr==4) C7 = cj_pmadd(A1, B3, C7); + C0 = cj.pmadd(A0, B0, C0); + C4 = cj.pmadd(A1, B0, C4); + C1 = cj.pmadd(A0, B1, C1); + C5 = cj.pmadd(A1, B1, C5); + if(nr==4) C2 = cj.pmadd(A0, B2, C2); + if(nr==4) C6 = cj.pmadd(A1, B2, C6); + if(nr==4) C3 = cj.pmadd(A0, B3, C3); + if(nr==4) C7 = cj.pmadd(A1, B3, C7); blB += 4*nr*PacketSize; blA += 4*mr; @@ -327,16 +345,16 @@ static void ei_cache_friendly_product( A1 = ei_pload(&blA[1*PacketSize]); B0 = ei_pload(&blB[0*PacketSize]); B1 = ei_pload(&blB[1*PacketSize]); - C0 = cj_pmadd(A0, B0, C0); + C0 = cj.pmadd(A0, B0, C0); if(nr==4) B2 = ei_pload(&blB[2*PacketSize]); - C4 = cj_pmadd(A1, B0, C4); + C4 = cj.pmadd(A1, B0, C4); if(nr==4) B3 = ei_pload(&blB[3*PacketSize]); - C1 = cj_pmadd(A0, B1, C1); - C5 = cj_pmadd(A1, B1, C5); - if(nr==4) C2 = cj_pmadd(A0, B2, C2); - if(nr==4) C6 = cj_pmadd(A1, B2, C6); - if(nr==4) C3 = cj_pmadd(A0, B3, C3); - if(nr==4) C7 = cj_pmadd(A1, B3, C7); + C1 = cj.pmadd(A0, B1, C1); + C5 = cj.pmadd(A1, B1, C5); + if(nr==4) C2 = cj.pmadd(A0, B2, C2); + if(nr==4) C6 = cj.pmadd(A1, B2, C6); + if(nr==4) C3 = cj.pmadd(A0, B3, C3); + if(nr==4) C7 = cj.pmadd(A1, B3, C7); blB += nr*PacketSize; blA += mr; @@ -368,12 +386,12 @@ static void ei_cache_friendly_product( A0 = blA[k]; B0 = blB[0*PacketSize]; B1 = blB[1*PacketSize]; - C0 = cj_pmadd(A0, B0, C0); + C0 = cj.pmadd(A0, B0, C0); if(nr==4) B2 = blB[2*PacketSize]; if(nr==4) B3 = blB[3*PacketSize]; - C1 = cj_pmadd(A0, B1, C1); - if(nr==4) C2 = cj_pmadd(A0, B2, C2); - if(nr==4) C3 = cj_pmadd(A0, B3, C3); + C1 = cj.pmadd(A0, B1, C1); + if(nr==4) C2 = cj.pmadd(A0, B2, C2); + if(nr==4) C3 = cj.pmadd(A0, B3, C3); blB += nr*PacketSize; } @@ -391,11 +409,11 @@ static void ei_cache_friendly_product( Scalar c0 = Scalar(0); if (lhsRowMajor) for(int k=0; k<actual_kc; k++) - c0 = cj_pmadd(lhs[(k2+k)+(i2+i)*lhsStride], rhs[j2*rhsStride + k2 + k], c0); + c0 += cj.pmul(lhs[(k2+k)+(i2+i)*lhsStride], rhs[j2*rhsStride + k2 + k]); else for(int k=0; k<actual_kc; k++) - c0 = cj_pmadd(lhs[(k2+k)*lhsStride + i2+i], rhs[j2*rhsStride + k2 + k], c0); - res[(j2)*resStride + i2+i] += alpha * c0; + c0 += cj.pmul(lhs[(k2+k)*lhsStride + i2+i], rhs[j2*rhsStride + k2 + k]); + res[(j2)*resStride + i2+i] += (ConjugateRhs ? ei_conj(alpha) : alpha) * c0; } } } @@ -493,39 +511,39 @@ static void ei_cache_friendly_product( L0 = ei_pload(&lb[1*PacketSize]); R1 = ei_pload(&lb[2*PacketSize]); L1 = ei_pload(&lb[3*PacketSize]); - T0 = cj_pmadd(A0, R0, T0); - T1 = cj_pmadd(A0, L0, T1); + T0 = cj.pmadd(A0, R0, T0); + T1 = cj.pmadd(A0, L0, T1); R0 = ei_pload(&lb[4*PacketSize]); L0 = ei_pload(&lb[5*PacketSize]); - T0 = cj_pmadd(A1, R1, T0); - T1 = cj_pmadd(A1, L1, T1); + T0 = cj.pmadd(A1, R1, T0); + T1 = cj.pmadd(A1, L1, T1); R1 = ei_pload(&lb[6*PacketSize]); L1 = ei_pload(&lb[7*PacketSize]); - T0 = cj_pmadd(A2, R0, T0); - T1 = cj_pmadd(A2, L0, T1); + T0 = cj.pmadd(A2, R0, T0); + T1 = cj.pmadd(A2, L0, T1); if(MaxBlockRows==8) { R0 = ei_pload(&lb[8*PacketSize]); L0 = ei_pload(&lb[9*PacketSize]); } - T0 = cj_pmadd(A3, R1, T0); - T1 = cj_pmadd(A3, L1, T1); + T0 = cj.pmadd(A3, R1, T0); + T1 = cj.pmadd(A3, L1, T1); if(MaxBlockRows==8) { R1 = ei_pload(&lb[10*PacketSize]); L1 = ei_pload(&lb[11*PacketSize]); - T0 = cj_pmadd(A4, R0, T0); - T1 = cj_pmadd(A4, L0, T1); + T0 = cj.pmadd(A4, R0, T0); + T1 = cj.pmadd(A4, L0, T1); R0 = ei_pload(&lb[12*PacketSize]); L0 = ei_pload(&lb[13*PacketSize]); - T0 = cj_pmadd(A5, R1, T0); - T1 = cj_pmadd(A5, L1, T1); + T0 = cj.pmadd(A5, R1, T0); + T1 = cj.pmadd(A5, L1, T1); R1 = ei_pload(&lb[14*PacketSize]); L1 = ei_pload(&lb[15*PacketSize]); - T0 = cj_pmadd(A6, R0, T0); - T1 = cj_pmadd(A6, L0, T1); - T0 = cj_pmadd(A7, R1, T0); - T1 = cj_pmadd(A7, L1, T1); + T0 = cj.pmadd(A6, R0, T0); + T1 = cj.pmadd(A6, L0, T1); + T0 = cj.pmadd(A7, R1, T0); + T1 = cj.pmadd(A7, L1, T1); } lb += MaxBlockRows*2*PacketSize; diff --git a/Eigen/src/Core/products/GeneralMatrixVector.h b/Eigen/src/Core/products/GeneralMatrixVector.h index 5cb0a9465..851bf808f 100644 --- a/Eigen/src/Core/products/GeneralMatrixVector.h +++ b/Eigen/src/Core/products/GeneralMatrixVector.h @@ -32,8 +32,9 @@ * same alignment pattern. * TODO: since rhs gets evaluated only once, no need to evaluate it */ -template<typename Scalar, typename RhsType> -static EIGEN_DONT_INLINE void ei_cache_friendly_product_colmajor_times_vector( +template<bool ConjugateLhs, bool ConjugateRhs, typename Scalar, typename RhsType> +static EIGEN_DONT_INLINE +void ei_cache_friendly_product_colmajor_times_vector( int size, const Scalar* lhs, int lhsStride, const RhsType& rhs, @@ -47,10 +48,14 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_colmajor_times_vector( ei_pstore(&res[j], \ ei_padd(ei_pload(&res[j]), \ ei_padd( \ - ei_padd(ei_pmul(ptmp0,EIGEN_CAT(ei_ploa , A0)(&lhs0[j])), \ - ei_pmul(ptmp1,EIGEN_CAT(ei_ploa , A13)(&lhs1[j]))), \ - ei_padd(ei_pmul(ptmp2,EIGEN_CAT(ei_ploa , A2)(&lhs2[j])), \ - ei_pmul(ptmp3,EIGEN_CAT(ei_ploa , A13)(&lhs3[j]))) ))) + ei_padd(cj.pmul(EIGEN_CAT(ei_ploa , A0)(&lhs0[j]), ptmp0), \ + cj.pmul(EIGEN_CAT(ei_ploa , A13)(&lhs1[j]), ptmp1)), \ + ei_padd(cj.pmul(EIGEN_CAT(ei_ploa , A2)(&lhs2[j]), ptmp2), \ + cj.pmul(EIGEN_CAT(ei_ploa , A13)(&lhs3[j]), ptmp3)) ))) + + ei_conj_helper<ConjugateLhs,ConjugateRhs> cj; + if(ConjugateRhs) + alpha = ei_conj(alpha); typedef typename ei_packet_traits<Scalar>::type Packet; const int PacketSize = sizeof(Packet)/sizeof(Scalar); @@ -109,7 +114,7 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_colmajor_times_vector( ptmp2 = ei_pset1(alpha*rhs[i+2]), ptmp3 = ei_pset1(alpha*rhs[i+offset3]); // this helps a lot generating better binary code - const Scalar *lhs0 = lhs + i*lhsStride, *lhs1 = lhs + (i+offset1)*lhsStride, + const Scalar *lhs0 = lhs + i*lhsStride, *lhs1 = lhs + (i+offset1)*lhsStride, *lhs2 = lhs + (i+2)*lhsStride, *lhs3 = lhs + (i+offset3)*lhsStride; if (PacketSize>1) @@ -117,7 +122,13 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_colmajor_times_vector( /* explicit vectorization */ // process initial unaligned coeffs for (int j=0; j<alignedStart; ++j) - res[j] += ei_pfirst(ptmp0)*lhs0[j] + ei_pfirst(ptmp1)*lhs1[j] + ei_pfirst(ptmp2)*lhs2[j] + ei_pfirst(ptmp3)*lhs3[j]; + { + res[j] = cj.pmadd(lhs0[j], ei_pfirst(ptmp0), res[j]); + res[j] = cj.pmadd(lhs1[j], ei_pfirst(ptmp1), res[j]); + res[j] = cj.pmadd(lhs2[j], ei_pfirst(ptmp2), res[j]); + res[j] = cj.pmadd(lhs3[j], ei_pfirst(ptmp3), res[j]); +// res[j] += ei_pfirst(ptmp0)*lhs0[j] + ei_pfirst(ptmp1)*lhs1[j] + ei_pfirst(ptmp2)*lhs2[j] + ei_pfirst(ptmp3)*lhs3[j]; + } if (alignedSize>alignedStart) { @@ -148,19 +159,19 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_colmajor_times_vector( A00 = ei_pload (&lhs0[j]); A10 = ei_pload (&lhs0[j+PacketSize]); - A00 = ei_pmadd(ptmp0, A00, ei_pload(&res[j])); - A10 = ei_pmadd(ptmp0, A10, ei_pload(&res[j+PacketSize])); + A00 = cj.pmadd(A00, ptmp0, ei_pload(&res[j])); + A10 = cj.pmadd(A10, ptmp0, ei_pload(&res[j+PacketSize])); - A00 = ei_pmadd(ptmp1, A01, A00); + A00 = cj.pmadd(A01, ptmp1, A00); A01 = ei_pload(&lhs1[j-1+2*PacketSize]); ei_palign<1>(A11,A01); - A00 = ei_pmadd(ptmp2, A02, A00); + A00 = cj.pmadd(A02, ptmp2, A00); A02 = ei_pload(&lhs2[j-2+2*PacketSize]); ei_palign<2>(A12,A02); - A00 = ei_pmadd(ptmp3, A03, A00); + A00 = cj.pmadd(A03, ptmp3, A00); ei_pstore(&res[j],A00); A03 = ei_pload(&lhs3[j-3+2*PacketSize]); ei_palign<3>(A13,A03); - A10 = ei_pmadd(ptmp1, A11, A10); - A10 = ei_pmadd(ptmp2, A12, A10); - A10 = ei_pmadd(ptmp3, A13, A10); + A10 = cj.pmadd(A11, ptmp1, A10); + A10 = cj.pmadd(A12, ptmp2, A10); + A10 = cj.pmadd(A13, ptmp3, A10); ei_pstore(&res[j+PacketSize],A10); } } @@ -177,7 +188,13 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_colmajor_times_vector( /* process remaining coeffs (or all if there is no explicit vectorization) */ for (int j=alignedSize; j<size; ++j) - res[j] += ei_pfirst(ptmp0)*lhs0[j] + ei_pfirst(ptmp1)*lhs1[j] + ei_pfirst(ptmp2)*lhs2[j] + ei_pfirst(ptmp3)*lhs3[j]; + { + res[j] = cj.pmadd(lhs0[j], ei_pfirst(ptmp0), res[j]); + res[j] = cj.pmadd(lhs1[j], ei_pfirst(ptmp1), res[j]); + res[j] = cj.pmadd(lhs2[j], ei_pfirst(ptmp2), res[j]); + res[j] = cj.pmadd(lhs3[j], ei_pfirst(ptmp3), res[j]); +// res[j] += ei_pfirst(ptmp0)*lhs0[j] + ei_pfirst(ptmp1)*lhs1[j] + ei_pfirst(ptmp2)*lhs2[j] + ei_pfirst(ptmp3)*lhs3[j]; + } } // process remaining first and last columns (at most columnsAtOnce-1) @@ -195,20 +212,20 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_colmajor_times_vector( /* explicit vectorization */ // process first unaligned result's coeffs for (int j=0; j<alignedStart; ++j) - res[j] += ei_pfirst(ptmp0) * lhs0[j]; + res[j] = cj.pmul(lhs0[j], ei_pfirst(ptmp0)); // process aligned result's coeffs if ((size_t(lhs0+alignedStart)%sizeof(Packet))==0) for (int j = alignedStart;j<alignedSize;j+=PacketSize) - ei_pstore(&res[j], ei_pmadd(ptmp0,ei_pload(&lhs0[j]),ei_pload(&res[j]))); + ei_pstore(&res[j], cj.pmadd(ei_pload(&lhs0[j]), ptmp0, ei_pload(&res[j]))); else for (int j = alignedStart;j<alignedSize;j+=PacketSize) - ei_pstore(&res[j], ei_pmadd(ptmp0,ei_ploadu(&lhs0[j]),ei_pload(&res[j]))); + ei_pstore(&res[j], cj.pmadd(ei_ploadu(&lhs0[j]), ptmp0, ei_pload(&res[j]))); } // process remaining scalars (or all if no explicit vectorization) for (int j=alignedSize; j<size; ++j) - res[j] += ei_pfirst(ptmp0) * lhs0[j]; + res[j] += cj.pmul(lhs0[j], ei_pfirst(ptmp0)); } if (skipColumns) { @@ -223,7 +240,7 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_colmajor_times_vector( } // TODO add peeling to mask unaligned load/stores -template<typename Scalar, typename ResType> +template<bool ConjugateLhs, bool ConjugateRhs, typename Scalar, typename ResType> static EIGEN_DONT_INLINE void ei_cache_friendly_product_rowmajor_times_vector( const Scalar* lhs, int lhsStride, const Scalar* rhs, int rhsSize, @@ -236,10 +253,12 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_rowmajor_times_vector( #define _EIGEN_ACCUMULATE_PACKETS(A0,A13,A2) {\ Packet b = ei_pload(&rhs[j]); \ - ptmp0 = ei_pmadd(b, EIGEN_CAT(ei_ploa,A0) (&lhs0[j]), ptmp0); \ - ptmp1 = ei_pmadd(b, EIGEN_CAT(ei_ploa,A13)(&lhs1[j]), ptmp1); \ - ptmp2 = ei_pmadd(b, EIGEN_CAT(ei_ploa,A2) (&lhs2[j]), ptmp2); \ - ptmp3 = ei_pmadd(b, EIGEN_CAT(ei_ploa,A13)(&lhs3[j]), ptmp3); } + ptmp0 = cj.pmadd(EIGEN_CAT(ei_ploa,A0) (&lhs0[j]), b, ptmp0); \ + ptmp1 = cj.pmadd(EIGEN_CAT(ei_ploa,A13)(&lhs1[j]), b, ptmp1); \ + ptmp2 = cj.pmadd(EIGEN_CAT(ei_ploa,A2) (&lhs2[j]), b, ptmp2); \ + ptmp3 = cj.pmadd(EIGEN_CAT(ei_ploa,A13)(&lhs3[j]), b, ptmp3); } + + ei_conj_helper<ConjugateLhs,ConjugateRhs> cj; typedef typename ei_packet_traits<Scalar>::type Packet; const int PacketSize = sizeof(Packet)/sizeof(Scalar); @@ -311,7 +330,8 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_rowmajor_times_vector( for (int j=0; j<alignedStart; ++j) { Scalar b = rhs[j]; - tmp0 += b*lhs0[j]; tmp1 += b*lhs1[j]; tmp2 += b*lhs2[j]; tmp3 += b*lhs3[j]; + tmp0 += cj.pmul(lhs0[j],b); tmp1 += cj.pmul(lhs1[j],b); + tmp2 += cj.pmul(lhs2[j],b); tmp3 += cj.pmul(lhs3[j],b); } if (alignedSize>alignedStart) @@ -347,19 +367,19 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_rowmajor_times_vector( A12 = ei_pload(&lhs2[j-2+PacketSize]); ei_palign<2>(A02,A12); A13 = ei_pload(&lhs3[j-3+PacketSize]); ei_palign<3>(A03,A13); - ptmp0 = ei_pmadd(b, ei_pload (&lhs0[j]), ptmp0); - ptmp1 = ei_pmadd(b, A01, ptmp1); + ptmp0 = cj.pmadd(ei_pload (&lhs0[j]), b, ptmp0); + ptmp1 = cj.pmadd(A01, b, ptmp1); A01 = ei_pload(&lhs1[j-1+2*PacketSize]); ei_palign<1>(A11,A01); - ptmp2 = ei_pmadd(b, A02, ptmp2); + ptmp2 = cj.pmadd(A02, b, ptmp2); A02 = ei_pload(&lhs2[j-2+2*PacketSize]); ei_palign<2>(A12,A02); - ptmp3 = ei_pmadd(b, A03, ptmp3); + ptmp3 = cj.pmadd(A03, b, ptmp3); A03 = ei_pload(&lhs3[j-3+2*PacketSize]); ei_palign<3>(A13,A03); b = ei_pload(&rhs[j+PacketSize]); - ptmp0 = ei_pmadd(b, ei_pload (&lhs0[j+PacketSize]), ptmp0); - ptmp1 = ei_pmadd(b, A11, ptmp1); - ptmp2 = ei_pmadd(b, A12, ptmp2); - ptmp3 = ei_pmadd(b, A13, ptmp3); + ptmp0 = cj.pmadd(ei_pload (&lhs0[j+PacketSize]), b, ptmp0); + ptmp1 = cj.pmadd(A11, b, ptmp1); + ptmp2 = cj.pmadd(A12, b, ptmp2); + ptmp3 = cj.pmadd(A13, b, ptmp3); } } for (int j = peeledSize; j<alignedSize; j+=PacketSize) @@ -382,7 +402,8 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_rowmajor_times_vector( for (int j=alignedSize; j<size; ++j) { Scalar b = rhs[j]; - tmp0 += b*lhs0[j]; tmp1 += b*lhs1[j]; tmp2 += b*lhs2[j]; tmp3 += b*lhs3[j]; + tmp0 += cj.pmul(lhs0[j],b); tmp1 += cj.pmul(lhs1[j],b); + tmp2 += cj.pmul(lhs2[j],b); tmp3 += cj.pmul(lhs3[j],b); } res[i] += alpha*tmp0; res[i+offset1] += alpha*tmp1; res[i+2] += alpha*tmp2; res[i+offset3] += alpha*tmp3; } @@ -400,24 +421,24 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_rowmajor_times_vector( // process first unaligned result's coeffs // FIXME this loop get vectorized by the compiler ! for (int j=0; j<alignedStart; ++j) - tmp0 += rhs[j] * lhs0[j]; + tmp0 += cj.pmul(lhs0[j], rhs[j]); if (alignedSize>alignedStart) { // process aligned rhs coeffs if ((size_t(lhs0+alignedStart)%sizeof(Packet))==0) for (int j = alignedStart;j<alignedSize;j+=PacketSize) - ptmp0 = ei_pmadd(ei_pload(&rhs[j]), ei_pload(&lhs0[j]), ptmp0); + ptmp0 = cj.pmadd(ei_pload(&lhs0[j]), ei_pload(&rhs[j]), ptmp0); else for (int j = alignedStart;j<alignedSize;j+=PacketSize) - ptmp0 = ei_pmadd(ei_pload(&rhs[j]), ei_ploadu(&lhs0[j]), ptmp0); + ptmp0 = cj.pmadd(ei_ploadu(&lhs0[j]), ei_pload(&rhs[j]), ptmp0); tmp0 += ei_predux(ptmp0); } // process remaining scalars // FIXME this loop get vectorized by the compiler ! for (int j=alignedSize; j<size; ++j) - tmp0 += rhs[j] * lhs0[j]; + tmp0 += cj.pmul(lhs0[j], rhs[j]); res[i] += alpha*tmp0; } if (skipRows) |