diff options
Diffstat (limited to 'Eigen/src/Core/products/GeneralBlockPanelKernel.h')
-rw-r--r-- | Eigen/src/Core/products/GeneralBlockPanelKernel.h | 132 |
1 files changed, 17 insertions, 115 deletions
diff --git a/Eigen/src/Core/products/GeneralBlockPanelKernel.h b/Eigen/src/Core/products/GeneralBlockPanelKernel.h index a96c7bfd4..10d132957 100644 --- a/Eigen/src/Core/products/GeneralBlockPanelKernel.h +++ b/Eigen/src/Core/products/GeneralBlockPanelKernel.h @@ -299,16 +299,6 @@ void computeProductBlockingSizes(Index& k, Index& m, Index& n, Index num_threads if (!useSpecificBlockingSizes(k, m, n)) { evaluateProductBlockingSizesHeuristic<LhsScalar, RhsScalar, KcFactor, Index>(k, m, n, num_threads); } - - typedef gebp_traits<LhsScalar,RhsScalar> Traits; - enum { - kr = 8, - mr = Traits::mr, - nr = Traits::nr - }; - if (k > kr) k -= k % kr; - if (m > mr) m -= m % mr; - if (n > nr) n -= n % nr; } template<typename LhsScalar, typename RhsScalar, typename Index> @@ -363,7 +353,7 @@ class gebp_traits public: typedef _LhsScalar LhsScalar; typedef _RhsScalar RhsScalar; - typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar; + typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; enum { ConjLhs = _ConjLhs, @@ -444,15 +434,16 @@ public: template<typename LhsPacketType, typename RhsPacketType, typename AccPacketType> EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, AccPacketType& tmp) const { + conj_helper<LhsPacketType,RhsPacketType,ConjLhs,ConjRhs> cj; // It would be a lot cleaner to call pmadd all the time. Unfortunately if we // let gcc allocate the register in which to store the result of the pmul // (in the case where there is no FMA) gcc fails to figure out how to avoid // spilling register. #ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD EIGEN_UNUSED_VARIABLE(tmp); - c = pmadd(a,b,c); + c = cj.pmadd(a,b,c); #else - tmp = b; tmp = pmul(a,tmp); c = padd(c,tmp); + tmp = b; tmp = cj.pmul(a,tmp); c = padd(c,tmp); #endif } @@ -467,9 +458,6 @@ public: r = pmadd(c,alpha,r); } -protected: -// conj_helper<LhsScalar,RhsScalar,ConjLhs,ConjRhs> cj; -// conj_helper<LhsPacket,RhsPacket,ConjLhs,ConjRhs> pcj; }; template<typename RealScalar, bool _ConjLhs> @@ -478,7 +466,7 @@ class gebp_traits<std::complex<RealScalar>, RealScalar, _ConjLhs, false> public: typedef std::complex<RealScalar> LhsScalar; typedef RealScalar RhsScalar; - typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar; + typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; enum { ConjLhs = _ConjLhs, @@ -860,80 +848,6 @@ protected: conj_helper<ResPacket,ResPacket,false,ConjRhs> cj; }; -// helper for the rotating kernel below -template <typename GebpKernel, bool UseRotatingKernel = GebpKernel::UseRotatingKernel> -struct PossiblyRotatingKernelHelper -{ - // default implementation, not rotating - - typedef typename GebpKernel::Traits Traits; - typedef typename Traits::RhsScalar RhsScalar; - typedef typename Traits::RhsPacket RhsPacket; - typedef typename Traits::AccPacket AccPacket; - - const Traits& traits; - PossiblyRotatingKernelHelper(const Traits& t) : traits(t) {} - - - template <size_t K, size_t Index> - void loadOrRotateRhs(RhsPacket& to, const RhsScalar* from) const - { - traits.loadRhs(from + (Index+4*K)*Traits::RhsProgress, to); - } - - void unrotateResult(AccPacket&, - AccPacket&, - AccPacket&, - AccPacket&) - { - } -}; - -// rotating implementation -template <typename GebpKernel> -struct PossiblyRotatingKernelHelper<GebpKernel, true> -{ - typedef typename GebpKernel::Traits Traits; - typedef typename Traits::RhsScalar RhsScalar; - typedef typename Traits::RhsPacket RhsPacket; - typedef typename Traits::AccPacket AccPacket; - - const Traits& traits; - PossiblyRotatingKernelHelper(const Traits& t) : traits(t) {} - - template <size_t K, size_t Index> - void loadOrRotateRhs(RhsPacket& to, const RhsScalar* from) const - { - if (Index == 0) { - to = pload<RhsPacket>(from + 4*K*Traits::RhsProgress); - } else { - EIGEN_ASM_COMMENT("Do not reorder code, we're very tight on registers"); - to = protate<1>(to); - } - } - - void unrotateResult(AccPacket& res0, - AccPacket& res1, - AccPacket& res2, - AccPacket& res3) - { - PacketBlock<AccPacket> resblock; - resblock.packet[0] = res0; - resblock.packet[1] = res1; - resblock.packet[2] = res2; - resblock.packet[3] = res3; - ptranspose(resblock); - resblock.packet[3] = protate<1>(resblock.packet[3]); - resblock.packet[2] = protate<2>(resblock.packet[2]); - resblock.packet[1] = protate<3>(resblock.packet[1]); - ptranspose(resblock); - res0 = resblock.packet[0]; - res1 = resblock.packet[1]; - res2 = resblock.packet[2]; - res3 = resblock.packet[3]; - } -}; - /* optimized GEneral packed Block * packed Panel product kernel * * Mixing type logic: C += A * B @@ -967,16 +881,6 @@ struct gebp_kernel ResPacketSize = Traits::ResPacketSize }; - - static const bool UseRotatingKernel = - EIGEN_ARCH_ARM && - internal::is_same<LhsScalar, float>::value && - internal::is_same<RhsScalar, float>::value && - internal::is_same<ResScalar, float>::value && - Traits::LhsPacketSize == 4 && - Traits::RhsPacketSize == 4 && - Traits::ResPacketSize == 4; - EIGEN_DONT_INLINE void operator()(const DataMapper& res, const LhsScalar* blockA, const RhsScalar* blockB, Index rows, Index depth, Index cols, ResScalar alpha, @@ -1009,9 +913,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga // This corresponds to 3*LhsProgress x nr register blocks. // Usually, make sense only with FMA if(mr>=3*Traits::LhsProgress) - { - PossiblyRotatingKernelHelper<gebp_kernel> possiblyRotatingKernelHelper(traits); - + { // Here, the general idea is to loop on each largest micro horizontal panel of the lhs (3*Traits::LhsProgress x depth) // and on each largest micro vertical panel of the rhs (depth * nr). // Blocking sizes, i.e., 'depth' has been computed so that the micro horizontal panel of the lhs fit in L1. @@ -1074,19 +976,19 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga traits.loadLhs(&blA[(0+3*K)*LhsProgress], A0); \ traits.loadLhs(&blA[(1+3*K)*LhsProgress], A1); \ traits.loadLhs(&blA[(2+3*K)*LhsProgress], A2); \ - possiblyRotatingKernelHelper.template loadOrRotateRhs<K, 0>(B_0, blB); \ + traits.loadRhs(blB + (0+4*K)*Traits::RhsProgress, B_0); \ traits.madd(A0, B_0, C0, T0); \ traits.madd(A1, B_0, C4, T0); \ traits.madd(A2, B_0, C8, B_0); \ - possiblyRotatingKernelHelper.template loadOrRotateRhs<K, 1>(B_0, blB); \ + traits.loadRhs(blB + (1+4*K)*Traits::RhsProgress, B_0); \ traits.madd(A0, B_0, C1, T0); \ traits.madd(A1, B_0, C5, T0); \ traits.madd(A2, B_0, C9, B_0); \ - possiblyRotatingKernelHelper.template loadOrRotateRhs<K, 2>(B_0, blB); \ + traits.loadRhs(blB + (2+4*K)*Traits::RhsProgress, B_0); \ traits.madd(A0, B_0, C2, T0); \ traits.madd(A1, B_0, C6, T0); \ traits.madd(A2, B_0, C10, B_0); \ - possiblyRotatingKernelHelper.template loadOrRotateRhs<K, 3>(B_0, blB); \ + traits.loadRhs(blB + (3+4*K)*Traits::RhsProgress, B_0); \ traits.madd(A0, B_0, C3 , T0); \ traits.madd(A1, B_0, C7, T0); \ traits.madd(A2, B_0, C11, B_0); \ @@ -1120,10 +1022,6 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga #undef EIGEN_GEBP_ONESTEP - possiblyRotatingKernelHelper.unrotateResult(C0, C1, C2, C3); - possiblyRotatingKernelHelper.unrotateResult(C4, C5, C6, C7); - possiblyRotatingKernelHelper.unrotateResult(C8, C9, C10, C11); - ResPacket R0, R1, R2; ResPacket alphav = pset1<ResPacket>(alpha); @@ -1625,9 +1523,13 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga prefetch(&blA[0]); const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr]; - // NOTE The following piece of code doesn't work for 512 bit registers, - // so we don't call it for registers that contain more than 8 values. - if( ((SwappedTraits::LhsProgress % 4)==0) && (SwappedTraits::LhsProgress <= 8)) + // The following piece of code wont work for 512 bit registers + // Moreover, if LhsProgress==8 it assumes that there is a half packet of the same size + // as nr (which is currently 4) for the return type. + typedef typename unpacket_traits<SResPacket>::half SResPacketHalf; + if ((SwappedTraits::LhsProgress % 4) == 0 && + (SwappedTraits::LhsProgress <= 8) && + (SwappedTraits::LhsProgress!=8 || unpacket_traits<SResPacketHalf>::size==nr)) { SAccPacket C0, C1, C2, C3; straits.initAcc(C0); |