diff options
Diffstat (limited to 'Eigen/src/Core/products/GeneralBlockPanelKernel.h')
-rw-r--r-- | Eigen/src/Core/products/GeneralBlockPanelKernel.h | 541 |
1 files changed, 292 insertions, 249 deletions
diff --git a/Eigen/src/Core/products/GeneralBlockPanelKernel.h b/Eigen/src/Core/products/GeneralBlockPanelKernel.h index b5f06d831..11e5f591d 100644 --- a/Eigen/src/Core/products/GeneralBlockPanelKernel.h +++ b/Eigen/src/Core/products/GeneralBlockPanelKernel.h @@ -26,28 +26,37 @@ inline std::ptrdiff_t manage_caching_sizes_helper(std::ptrdiff_t a, std::ptrdiff } /** \internal */ -inline void manage_caching_sizes(Action action, std::ptrdiff_t* l1=0, std::ptrdiff_t* l2=0) +inline void manage_caching_sizes(Action action, std::ptrdiff_t* l1, std::ptrdiff_t* l2, std::ptrdiff_t* l3) { - static std::ptrdiff_t m_l1CacheSize = 0; - static std::ptrdiff_t m_l2CacheSize = 0; - if(m_l2CacheSize==0) + static bool m_cache_sizes_initialized = false; + static std::ptrdiff_t m_l1CacheSize = 32*1024; + static std::ptrdiff_t m_l2CacheSize = 256*1024; + static std::ptrdiff_t m_l3CacheSize = 2*1024*1024; + + if(!m_cache_sizes_initialized) { - m_l1CacheSize = manage_caching_sizes_helper(queryL1CacheSize(),8 * 1024); - m_l2CacheSize = manage_caching_sizes_helper(queryTopLevelCacheSize(),1*1024*1024); + int l1CacheSize, l2CacheSize, l3CacheSize; + queryCacheSizes(l1CacheSize, l2CacheSize, l3CacheSize); + m_l1CacheSize = manage_caching_sizes_helper(l1CacheSize, 8*1024); + m_l2CacheSize = manage_caching_sizes_helper(l2CacheSize, 256*1024); + m_l3CacheSize = manage_caching_sizes_helper(l3CacheSize, 8*1024*1024); + m_cache_sizes_initialized = true; } - + if(action==SetAction) { // set the cpu cache size and cache all block sizes from a global cache size in byte eigen_internal_assert(l1!=0 && l2!=0); m_l1CacheSize = *l1; m_l2CacheSize = *l2; + m_l3CacheSize = *l3; } else if(action==GetAction) { eigen_internal_assert(l1!=0 && l2!=0); *l1 = m_l1CacheSize; *l2 = m_l2CacheSize; + *l3 = m_l3CacheSize; } else { @@ -70,10 +79,11 @@ inline void manage_caching_sizes(Action action, std::ptrdiff_t* l1=0, std::ptrdi * - the number of scalars that fit into a packet (when vectorization is enabled). * * \sa setCpuCacheSizes */ +#define CEIL(a, b) ((a)+(b)-1)/(b) + template<typename LhsScalar, typename RhsScalar, int KcFactor, typename SizeType> -void computeProductBlockingSizes(SizeType& k, SizeType& m, SizeType& n) +void computeProductBlockingSizes(SizeType& k, SizeType& m, SizeType& n, int num_threads) { - EIGEN_UNUSED_VARIABLE(n); // Explanations: // Let's recall the product algorithms form kc x nc horizontal panels B' on the rhs and // mc x kc blocks A' on the lhs. A' has to fit into L2 cache. Moreover, B' is processed @@ -81,43 +91,71 @@ void computeProductBlockingSizes(SizeType& k, SizeType& m, SizeType& n) // at the register level. For vectorization purpose, these small vertical panels are unpacked, // e.g., each coefficient is replicated to fit a packet. This small vertical panel has to // stay in L1 cache. - std::ptrdiff_t l1, l2; - - typedef gebp_traits<LhsScalar,RhsScalar> Traits; - enum { - kdiv = KcFactor * 2 * Traits::nr - * Traits::RhsProgress * sizeof(RhsScalar), - mr = gebp_traits<LhsScalar,RhsScalar>::mr, - mr_mask = (0xffffffff/mr)*mr - }; + std::ptrdiff_t l1, l2, l3; + manage_caching_sizes(GetAction, &l1, &l2, &l3); + + if (num_threads > 1) { + typedef gebp_traits<LhsScalar,RhsScalar> Traits; + typedef typename Traits::ResScalar ResScalar; + enum { + kdiv = KcFactor * (Traits::mr * sizeof(LhsScalar) + Traits::nr * sizeof(RhsScalar)), + ksub = Traits::mr * Traits::nr * sizeof(ResScalar), + k_mask = (0xffffffff/8)*8, + + mr = Traits::mr, + mr_mask = (0xffffffff/mr)*mr, + + nr = Traits::nr, + nr_mask = (0xffffffff/nr)*nr + }; + SizeType k_cache = (l1-ksub)/kdiv; + if (k_cache < k) { + k = k_cache & k_mask; + eigen_assert(k > 0); + } - manage_caching_sizes(GetAction, &l1, &l2); + SizeType n_cache = (l2-l1) / (nr * sizeof(RhsScalar) * k); + SizeType n_per_thread = CEIL(n, num_threads); + if (n_cache <= n_per_thread) { + // Don't exceed the capacity of the l2 cache. + eigen_assert(n_cache >= static_cast<SizeType>(nr)); + n = n_cache & nr_mask; + eigen_assert(n > 0); + } else { + n = (std::min<SizeType>)(n, (n_per_thread + nr - 1) & nr_mask); + } -// k = std::min<SizeType>(k, l1/kdiv); -// SizeType _m = k>0 ? l2/(4 * sizeof(LhsScalar) * k) : 0; -// if(_m<m) m = _m & mr_mask; - - // In unit tests we do not want to use extra large matrices, - // so we reduce the block size to check the blocking strategy is not flawed + if (l3 > l2) { + // l3 is shared between all cores, so we'll give each thread its own chunk of l3. + SizeType m_cache = (l3-l2) / (sizeof(LhsScalar) * k * num_threads); + SizeType m_per_thread = CEIL(m, num_threads); + if(m_cache < m_per_thread && m_cache >= static_cast<SizeType>(mr)) { + m = m_cache & mr_mask; + eigen_assert(m > 0); + } else { + m = (std::min<SizeType>)(m, (m_per_thread + mr - 1) & mr_mask); + } + } + } + else { + // In unit tests we do not want to use extra large matrices, + // so we reduce the block size to check the blocking strategy is not flawed #ifndef EIGEN_DEBUG_SMALL_PRODUCT_BLOCKS -// k = std::min<SizeType>(k,240); -// n = std::min<SizeType>(n,3840/sizeof(RhsScalar)); -// m = std::min<SizeType>(m,3840/sizeof(RhsScalar)); - - k = std::min<SizeType>(k,sizeof(LhsScalar)<=4 ? 360 : 240); - n = std::min<SizeType>(n,3840/sizeof(RhsScalar)); - m = std::min<SizeType>(m,3840/sizeof(RhsScalar)); + k = std::min<SizeType>(k,sizeof(LhsScalar)<=4 ? 360 : 240); + n = std::min<SizeType>(n,3840/sizeof(RhsScalar)); + m = std::min<SizeType>(m,3840/sizeof(RhsScalar)); #else - k = std::min<SizeType>(k,24); - n = std::min<SizeType>(n,384/sizeof(RhsScalar)); - m = std::min<SizeType>(m,384/sizeof(RhsScalar)); + k = std::min<SizeType>(k,24); + n = std::min<SizeType>(n,384/sizeof(RhsScalar)); + m = std::min<SizeType>(m,384/sizeof(RhsScalar)); #endif + } } template<typename LhsScalar, typename RhsScalar, typename SizeType> -inline void computeProductBlockingSizes(SizeType& k, SizeType& m, SizeType& n) +inline void computeProductBlockingSizes(SizeType& k, SizeType& m, SizeType& n, int num_threads) { - computeProductBlockingSizes<LhsScalar,RhsScalar,1>(k, m, n); + computeProductBlockingSizes<LhsScalar,RhsScalar,1>(k, m, n, num_threads); } #ifdef EIGEN_HAS_SINGLE_INSTRUCTION_CJMADD @@ -667,7 +705,7 @@ protected: * |real |cplx | no vectorization yet, would require to pack A with duplication * |cplx |real | easy vectorization */ -template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> +template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> struct gebp_kernel { typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> Traits; @@ -676,14 +714,15 @@ struct gebp_kernel typedef typename Traits::RhsPacket RhsPacket; typedef typename Traits::ResPacket ResPacket; typedef typename Traits::AccPacket AccPacket; - + typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs> SwappedTraits; typedef typename SwappedTraits::ResScalar SResScalar; typedef typename SwappedTraits::LhsPacket SLhsPacket; typedef typename SwappedTraits::RhsPacket SRhsPacket; typedef typename SwappedTraits::ResPacket SResPacket; typedef typename SwappedTraits::AccPacket SAccPacket; - + + typedef typename DataMapper::LinearMapper LinearMapper; enum { Vectorizable = Traits::Vectorizable, @@ -693,14 +732,16 @@ struct gebp_kernel }; EIGEN_DONT_INLINE - void operator()(ResScalar* res, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB, Index rows, Index depth, Index cols, ResScalar alpha, + void operator()(const DataMapper& res, const LhsScalar* blockA, const RhsScalar* blockB, + Index rows, Index depth, Index cols, ResScalar alpha, Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); }; -template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> +template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> EIGEN_DONT_INLINE -void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> - ::operator()(ResScalar* res, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB, Index rows, Index depth, Index cols, ResScalar alpha, +void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,ConjugateRhs> + ::operator()(const DataMapper& res, const LhsScalar* blockA, const RhsScalar* blockB, + Index rows, Index depth, Index cols, ResScalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) { Traits traits; @@ -743,15 +784,15 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> traits.initAcc(C4); traits.initAcc(C5); traits.initAcc(C6); traits.initAcc(C7); traits.initAcc(C8); traits.initAcc(C9); traits.initAcc(C10); traits.initAcc(C11); - ResScalar* r0 = &res[(j2+0)*resStride + i]; - ResScalar* r1 = &res[(j2+1)*resStride + i]; - ResScalar* r2 = &res[(j2+2)*resStride + i]; - ResScalar* r3 = &res[(j2+3)*resStride + i]; - - internal::prefetch(r0); - internal::prefetch(r1); - internal::prefetch(r2); - internal::prefetch(r3); + LinearMapper r0 = res.getLinearMapper(i, j2 + 0); + LinearMapper r1 = res.getLinearMapper(i, j2 + 1); + LinearMapper r2 = res.getLinearMapper(i, j2 + 2); + LinearMapper r3 = res.getLinearMapper(i, j2 + 3); + + r0.prefetch(0); + r1.prefetch(0); + r2.prefetch(0); + r3.prefetch(0); // performs "inner" products const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr]; @@ -820,48 +861,48 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> ResPacket R0, R1, R2; ResPacket alphav = pset1<ResPacket>(alpha); - - R0 = ploadu<ResPacket>(r0+0*Traits::ResPacketSize); - R1 = ploadu<ResPacket>(r0+1*Traits::ResPacketSize); - R2 = ploadu<ResPacket>(r0+2*Traits::ResPacketSize); + + R0 = r0.loadPacket(0 * Traits::ResPacketSize); + R1 = r0.loadPacket(1 * Traits::ResPacketSize); + R2 = r0.loadPacket(2 * Traits::ResPacketSize); traits.acc(C0, alphav, R0); traits.acc(C4, alphav, R1); traits.acc(C8, alphav, R2); - pstoreu(r0+0*Traits::ResPacketSize, R0); - pstoreu(r0+1*Traits::ResPacketSize, R1); - pstoreu(r0+2*Traits::ResPacketSize, R2); - - R0 = ploadu<ResPacket>(r1+0*Traits::ResPacketSize); - R1 = ploadu<ResPacket>(r1+1*Traits::ResPacketSize); - R2 = ploadu<ResPacket>(r1+2*Traits::ResPacketSize); + r0.storePacket(0 * Traits::ResPacketSize, R0); + r0.storePacket(1 * Traits::ResPacketSize, R1); + r0.storePacket(2 * Traits::ResPacketSize, R2); + + R0 = r1.loadPacket(0 * Traits::ResPacketSize); + R1 = r1.loadPacket(1 * Traits::ResPacketSize); + R2 = r1.loadPacket(2 * Traits::ResPacketSize); traits.acc(C1, alphav, R0); traits.acc(C5, alphav, R1); traits.acc(C9, alphav, R2); - pstoreu(r1+0*Traits::ResPacketSize, R0); - pstoreu(r1+1*Traits::ResPacketSize, R1); - pstoreu(r1+2*Traits::ResPacketSize, R2); - - R0 = ploadu<ResPacket>(r2+0*Traits::ResPacketSize); - R1 = ploadu<ResPacket>(r2+1*Traits::ResPacketSize); - R2 = ploadu<ResPacket>(r2+2*Traits::ResPacketSize); + r1.storePacket(0 * Traits::ResPacketSize, R0); + r1.storePacket(1 * Traits::ResPacketSize, R1); + r1.storePacket(2 * Traits::ResPacketSize, R2); + + R0 = r2.loadPacket(0 * Traits::ResPacketSize); + R1 = r2.loadPacket(1 * Traits::ResPacketSize); + R2 = r2.loadPacket(2 * Traits::ResPacketSize); traits.acc(C2, alphav, R0); traits.acc(C6, alphav, R1); traits.acc(C10, alphav, R2); - pstoreu(r2+0*Traits::ResPacketSize, R0); - pstoreu(r2+1*Traits::ResPacketSize, R1); - pstoreu(r2+2*Traits::ResPacketSize, R2); - - R0 = ploadu<ResPacket>(r3+0*Traits::ResPacketSize); - R1 = ploadu<ResPacket>(r3+1*Traits::ResPacketSize); - R2 = ploadu<ResPacket>(r3+2*Traits::ResPacketSize); + r2.storePacket(0 * Traits::ResPacketSize, R0); + r2.storePacket(1 * Traits::ResPacketSize, R1); + r2.storePacket(2 * Traits::ResPacketSize, R2); + + R0 = r3.loadPacket(0 * Traits::ResPacketSize); + R1 = r3.loadPacket(1 * Traits::ResPacketSize); + R2 = r3.loadPacket(2 * Traits::ResPacketSize); traits.acc(C3, alphav, R0); traits.acc(C7, alphav, R1); traits.acc(C11, alphav, R2); - pstoreu(r3+0*Traits::ResPacketSize, R0); - pstoreu(r3+1*Traits::ResPacketSize, R1); - pstoreu(r3+2*Traits::ResPacketSize, R2); + r3.storePacket(0 * Traits::ResPacketSize, R0); + r3.storePacket(1 * Traits::ResPacketSize, R1); + r3.storePacket(2 * Traits::ResPacketSize, R2); } - + // Deal with remaining columns of the rhs for(Index j2=packet_cols4; j2<cols; j2++) { @@ -875,7 +916,8 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> traits.initAcc(C4); traits.initAcc(C8); - ResScalar* r0 = &res[(j2+0)*resStride + i]; + LinearMapper r0 = res.getLinearMapper(i, j2); + r0.prefetch(0); // performs "inner" products const RhsScalar* blB = &blockB[j2*strideB+offsetB]; @@ -926,19 +968,19 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> ResPacket R0, R1, R2; ResPacket alphav = pset1<ResPacket>(alpha); - R0 = ploadu<ResPacket>(r0+0*Traits::ResPacketSize); - R1 = ploadu<ResPacket>(r0+1*Traits::ResPacketSize); - R2 = ploadu<ResPacket>(r0+2*Traits::ResPacketSize); + R0 = r0.loadPacket(0 * Traits::ResPacketSize); + R1 = r0.loadPacket(1 * Traits::ResPacketSize); + R2 = r0.loadPacket(2 * Traits::ResPacketSize); traits.acc(C0, alphav, R0); traits.acc(C4, alphav, R1); - traits.acc(C8 , alphav, R2); - pstoreu(r0+0*Traits::ResPacketSize, R0); - pstoreu(r0+1*Traits::ResPacketSize, R1); - pstoreu(r0+2*Traits::ResPacketSize, R2); + traits.acc(C8, alphav, R2); + r0.storePacket(0 * Traits::ResPacketSize, R0); + r0.storePacket(1 * Traits::ResPacketSize, R1); + r0.storePacket(2 * Traits::ResPacketSize, R2); } } } - + //---------- Process 2 * LhsProgress rows at once ---------- if(mr>=2*Traits::LhsProgress) { @@ -960,15 +1002,15 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> traits.initAcc(C0); traits.initAcc(C1); traits.initAcc(C2); traits.initAcc(C3); traits.initAcc(C4); traits.initAcc(C5); traits.initAcc(C6); traits.initAcc(C7); - ResScalar* r0 = &res[(j2+0)*resStride + i]; - ResScalar* r1 = &res[(j2+1)*resStride + i]; - ResScalar* r2 = &res[(j2+2)*resStride + i]; - ResScalar* r3 = &res[(j2+3)*resStride + i]; - - internal::prefetch(r0+prefetch_res_offset); - internal::prefetch(r1+prefetch_res_offset); - internal::prefetch(r2+prefetch_res_offset); - internal::prefetch(r3+prefetch_res_offset); + LinearMapper r0 = res.getLinearMapper(i, j2 + 0); + LinearMapper r1 = res.getLinearMapper(i, j2 + 1); + LinearMapper r2 = res.getLinearMapper(i, j2 + 2); + LinearMapper r3 = res.getLinearMapper(i, j2 + 3); + + r0.prefetch(prefetch_res_offset); + r1.prefetch(prefetch_res_offset); + r2.prefetch(prefetch_res_offset); + r3.prefetch(prefetch_res_offset); // performs "inner" products const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr]; @@ -1023,37 +1065,37 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> blA += 2*Traits::LhsProgress; } #undef EIGEN_GEBGP_ONESTEP - + ResPacket R0, R1, R2, R3; ResPacket alphav = pset1<ResPacket>(alpha); - - R0 = ploadu<ResPacket>(r0+0*Traits::ResPacketSize); - R1 = ploadu<ResPacket>(r0+1*Traits::ResPacketSize); - R2 = ploadu<ResPacket>(r1+0*Traits::ResPacketSize); - R3 = ploadu<ResPacket>(r1+1*Traits::ResPacketSize); + + R0 = r0.loadPacket(0 * Traits::ResPacketSize); + R1 = r0.loadPacket(1 * Traits::ResPacketSize); + R2 = r1.loadPacket(0 * Traits::ResPacketSize); + R3 = r1.loadPacket(1 * Traits::ResPacketSize); traits.acc(C0, alphav, R0); traits.acc(C4, alphav, R1); traits.acc(C1, alphav, R2); traits.acc(C5, alphav, R3); - pstoreu(r0+0*Traits::ResPacketSize, R0); - pstoreu(r0+1*Traits::ResPacketSize, R1); - pstoreu(r1+0*Traits::ResPacketSize, R2); - pstoreu(r1+1*Traits::ResPacketSize, R3); - - R0 = ploadu<ResPacket>(r2+0*Traits::ResPacketSize); - R1 = ploadu<ResPacket>(r2+1*Traits::ResPacketSize); - R2 = ploadu<ResPacket>(r3+0*Traits::ResPacketSize); - R3 = ploadu<ResPacket>(r3+1*Traits::ResPacketSize); + r0.storePacket(0 * Traits::ResPacketSize, R0); + r0.storePacket(1 * Traits::ResPacketSize, R1); + r1.storePacket(0 * Traits::ResPacketSize, R2); + r1.storePacket(1 * Traits::ResPacketSize, R3); + + R0 = r2.loadPacket(0 * Traits::ResPacketSize); + R1 = r2.loadPacket(1 * Traits::ResPacketSize); + R2 = r3.loadPacket(0 * Traits::ResPacketSize); + R3 = r3.loadPacket(1 * Traits::ResPacketSize); traits.acc(C2, alphav, R0); traits.acc(C6, alphav, R1); traits.acc(C3, alphav, R2); traits.acc(C7, alphav, R3); - pstoreu(r2+0*Traits::ResPacketSize, R0); - pstoreu(r2+1*Traits::ResPacketSize, R1); - pstoreu(r3+0*Traits::ResPacketSize, R2); - pstoreu(r3+1*Traits::ResPacketSize, R3); + r2.storePacket(0 * Traits::ResPacketSize, R0); + r2.storePacket(1 * Traits::ResPacketSize, R1); + r3.storePacket(0 * Traits::ResPacketSize, R2); + r3.storePacket(1 * Traits::ResPacketSize, R3); } - + // Deal with remaining columns of the rhs for(Index j2=packet_cols4; j2<cols; j2++) { @@ -1066,8 +1108,8 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> traits.initAcc(C0); traits.initAcc(C4); - ResScalar* r0 = &res[(j2+0)*resStride + i]; - internal::prefetch(r0+prefetch_res_offset); + LinearMapper r0 = res.getLinearMapper(i, j2); + r0.prefetch(prefetch_res_offset); // performs "inner" products const RhsScalar* blB = &blockB[j2*strideB+offsetB]; @@ -1117,12 +1159,12 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> ResPacket R0, R1; ResPacket alphav = pset1<ResPacket>(alpha); - R0 = ploadu<ResPacket>(r0+0*Traits::ResPacketSize); - R1 = ploadu<ResPacket>(r0+1*Traits::ResPacketSize); + R0 = r0.loadPacket(0 * Traits::ResPacketSize); + R1 = r0.loadPacket(1 * Traits::ResPacketSize); traits.acc(C0, alphav, R0); traits.acc(C4, alphav, R1); - pstoreu(r0+0*Traits::ResPacketSize, R0); - pstoreu(r0+1*Traits::ResPacketSize, R1); + r0.storePacket(0 * Traits::ResPacketSize, R0); + r0.storePacket(1 * Traits::ResPacketSize, R1); } } } @@ -1148,15 +1190,15 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> traits.initAcc(C2); traits.initAcc(C3); - ResScalar* r0 = &res[(j2+0)*resStride + i]; - ResScalar* r1 = &res[(j2+1)*resStride + i]; - ResScalar* r2 = &res[(j2+2)*resStride + i]; - ResScalar* r3 = &res[(j2+3)*resStride + i]; - - internal::prefetch(r0+prefetch_res_offset); - internal::prefetch(r1+prefetch_res_offset); - internal::prefetch(r2+prefetch_res_offset); - internal::prefetch(r3+prefetch_res_offset); + LinearMapper r0 = res.getLinearMapper(i, j2 + 0); + LinearMapper r1 = res.getLinearMapper(i, j2 + 1); + LinearMapper r2 = res.getLinearMapper(i, j2 + 2); + LinearMapper r3 = res.getLinearMapper(i, j2 + 3); + + r0.prefetch(prefetch_res_offset); + r1.prefetch(prefetch_res_offset); + r2.prefetch(prefetch_res_offset); + r3.prefetch(prefetch_res_offset); // performs "inner" products const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr]; @@ -1206,25 +1248,25 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> blA += 1*LhsProgress; } #undef EIGEN_GEBGP_ONESTEP - + ResPacket R0, R1; ResPacket alphav = pset1<ResPacket>(alpha); - - R0 = ploadu<ResPacket>(r0+0*Traits::ResPacketSize); - R1 = ploadu<ResPacket>(r1+0*Traits::ResPacketSize); + + R0 = r0.loadPacket(0 * Traits::ResPacketSize); + R1 = r1.loadPacket(0 * Traits::ResPacketSize); traits.acc(C0, alphav, R0); traits.acc(C1, alphav, R1); - pstoreu(r0+0*Traits::ResPacketSize, R0); - pstoreu(r1+0*Traits::ResPacketSize, R1); - - R0 = ploadu<ResPacket>(r2+0*Traits::ResPacketSize); - R1 = ploadu<ResPacket>(r3+0*Traits::ResPacketSize); + r0.storePacket(0 * Traits::ResPacketSize, R0); + r1.storePacket(0 * Traits::ResPacketSize, R1); + + R0 = r2.loadPacket(0 * Traits::ResPacketSize); + R1 = r3.loadPacket(0 * Traits::ResPacketSize); traits.acc(C2, alphav, R0); traits.acc(C3, alphav, R1); - pstoreu(r2+0*Traits::ResPacketSize, R0); - pstoreu(r3+0*Traits::ResPacketSize, R1); + r2.storePacket(0 * Traits::ResPacketSize, R0); + r3.storePacket(0 * Traits::ResPacketSize, R1); } - + // Deal with remaining columns of the rhs for(Index j2=packet_cols4; j2<cols; j2++) { @@ -1236,7 +1278,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> AccPacket C0; traits.initAcc(C0); - ResScalar* r0 = &res[(j2+0)*resStride + i]; + LinearMapper r0 = res.getLinearMapper(i, j2); // performs "inner" products const RhsScalar* blB = &blockB[j2*strideB+offsetB]; @@ -1283,9 +1325,9 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> #undef EIGEN_GEBGP_ONESTEP ResPacket R0; ResPacket alphav = pset1<ResPacket>(alpha); - R0 = ploadu<ResPacket>(r0+0*Traits::ResPacketSize); + R0 = r0.loadPacket(0 * Traits::ResPacketSize); traits.acc(C0, alphav, R0); - pstoreu(r0+0*Traits::ResPacketSize, R0); + r0.storePacket(0 * Traits::ResPacketSize, R0); } } } @@ -1301,7 +1343,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> const LhsScalar* blA = &blockA[i*strideA+offsetA]; prefetch(&blA[0]); const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr]; - + if( (SwappedTraits::LhsProgress % 4)==0 ) { // NOTE The following piece of code wont work for 512 bit registers @@ -1310,32 +1352,32 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> straits.initAcc(C1); straits.initAcc(C2); straits.initAcc(C3); - + const Index spk = (std::max)(1,SwappedTraits::LhsProgress/4); const Index endk = (depth/spk)*spk; const Index endk4 = (depth/(spk*4))*(spk*4); - + Index k=0; for(; k<endk4; k+=4*spk) { SLhsPacket A0,A1; SRhsPacket B_0,B_1; - + straits.loadLhsUnaligned(blB+0*SwappedTraits::LhsProgress, A0); straits.loadLhsUnaligned(blB+1*SwappedTraits::LhsProgress, A1); - + straits.loadRhsQuad(blA+0*spk, B_0); straits.loadRhsQuad(blA+1*spk, B_1); straits.madd(A0,B_0,C0,B_0); straits.madd(A1,B_1,C1,B_1); - + straits.loadLhsUnaligned(blB+2*SwappedTraits::LhsProgress, A0); straits.loadLhsUnaligned(blB+3*SwappedTraits::LhsProgress, A1); straits.loadRhsQuad(blA+2*spk, B_0); straits.loadRhsQuad(blA+3*spk, B_1); straits.madd(A0,B_0,C2,B_0); straits.madd(A1,B_1,C3,B_1); - + blB += 4*SwappedTraits::LhsProgress; blA += 4*spk; } @@ -1344,11 +1386,11 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> { SLhsPacket A0; SRhsPacket B_0; - + straits.loadLhsUnaligned(blB, A0); straits.loadRhsQuad(blA, B_0); straits.madd(A0,B_0,C0,B_0); - + blB += SwappedTraits::LhsProgress; blA += spk; } @@ -1359,10 +1401,10 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> typedef typename conditional<SwappedTraits::LhsProgress==8,typename unpacket_traits<SLhsPacket>::half,SLhsPacket>::type SLhsPacketHalf; typedef typename conditional<SwappedTraits::LhsProgress==8,typename unpacket_traits<SLhsPacket>::half,SRhsPacket>::type SRhsPacketHalf; typedef typename conditional<SwappedTraits::LhsProgress==8,typename unpacket_traits<SAccPacket>::half,SAccPacket>::type SAccPacketHalf; - - SResPacketHalf R = pgather<SResScalar, SResPacketHalf>(&res[j2*resStride + i], resStride); + + SResPacketHalf R = res.template gatherPacket<SResPacketHalf>(i, j2); SResPacketHalf alphav = pset1<SResPacketHalf>(alpha); - + if(depth-endk>0) { // We have to handle the last row of the rhs which corresponds to a half-packet @@ -1378,14 +1420,14 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> { straits.acc(predux4(C0), alphav, R); } - pscatter(&res[j2*resStride + i], R, resStride); + res.scatterPacket(i, j2, R); } else { - SResPacket R = pgather<SResScalar, SResPacket>(&res[j2*resStride + i], resStride); + SResPacket R = res.template gatherPacket<SResPacket>(i, j2); SResPacket alphav = pset1<SResPacket>(alpha); straits.acc(C0, alphav, R); - pscatter(&res[j2*resStride + i], R, resStride); + res.scatterPacket(i, j2, R); } } else // scalar path @@ -1397,9 +1439,9 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> { LhsScalar A0; RhsScalar B_0, B_1; - + A0 = blA[k]; - + B_0 = blB[0]; B_1 = blB[1]; CJMADD(cj,A0,B_0,C0, B_0); @@ -1412,10 +1454,10 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> blB += 4; } - res[(j2+0)*resStride + i] += alpha*C0; - res[(j2+1)*resStride + i] += alpha*C1; - res[(j2+2)*resStride + i] += alpha*C2; - res[(j2+3)*resStride + i] += alpha*C3; + res(i, j2 + 0) += alpha * C0; + res(i, j2 + 1) += alpha * C1; + res(i, j2 + 2) += alpha * C2; + res(i, j2 + 3) += alpha * C3; } } } @@ -1436,7 +1478,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> RhsScalar B_0 = blB[k]; CJMADD(cj, A0, B_0, C0, B_0); } - res[(j2+0)*resStride + i] += alpha*C0; + res(i, j2) += alpha * C0; } } } @@ -1459,15 +1501,16 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs> // // 32 33 34 35 ... // 36 36 38 39 ... -template<typename Scalar, typename Index, int Pack1, int Pack2, bool Conjugate, bool PanelMode> -struct gemm_pack_lhs<Scalar, Index, Pack1, Pack2, ColMajor, Conjugate, PanelMode> +template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, bool Conjugate, bool PanelMode> +struct gemm_pack_lhs<Scalar, Index, DataMapper, Pack1, Pack2, ColMajor, Conjugate, PanelMode> { - EIGEN_DONT_INLINE void operator()(Scalar* blockA, const Scalar* EIGEN_RESTRICT _lhs, Index lhsStride, Index depth, Index rows, Index stride=0, Index offset=0); + typedef typename DataMapper::LinearMapper LinearMapper; + EIGEN_DONT_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); }; -template<typename Scalar, typename Index, int Pack1, int Pack2, bool Conjugate, bool PanelMode> -EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, ColMajor, Conjugate, PanelMode> - ::operator()(Scalar* blockA, const Scalar* EIGEN_RESTRICT _lhs, Index lhsStride, Index depth, Index rows, Index stride, Index offset) +template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, bool Conjugate, bool PanelMode> +EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, DataMapper, Pack1, Pack2, ColMajor, Conjugate, PanelMode> + ::operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) { typedef typename packet_traits<Scalar>::type Packet; enum { PacketSize = packet_traits<Scalar>::size }; @@ -1478,30 +1521,29 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, ColMajor, Conj eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride)); eigen_assert( ((Pack1%PacketSize)==0 && Pack1<=4*PacketSize) || (Pack1<=4) ); conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj; - const_blas_data_mapper<Scalar, Index, ColMajor> lhs(_lhs,lhsStride); Index count = 0; - + const Index peeled_mc3 = Pack1>=3*PacketSize ? (rows/(3*PacketSize))*(3*PacketSize) : 0; const Index peeled_mc2 = Pack1>=2*PacketSize ? peeled_mc3+((rows-peeled_mc3)/(2*PacketSize))*(2*PacketSize) : 0; const Index peeled_mc1 = Pack1>=1*PacketSize ? (rows/(1*PacketSize))*(1*PacketSize) : 0; const Index peeled_mc0 = Pack2>=1*PacketSize ? peeled_mc1 : Pack2>1 ? (rows/Pack2)*Pack2 : 0; - + Index i=0; - + // Pack 3 packets if(Pack1>=3*PacketSize) { for(; i<peeled_mc3; i+=3*PacketSize) { if(PanelMode) count += (3*PacketSize) * offset; - + for(Index k=0; k<depth; k++) { Packet A, B, C; - A = ploadu<Packet>(&lhs(i+0*PacketSize, k)); - B = ploadu<Packet>(&lhs(i+1*PacketSize, k)); - C = ploadu<Packet>(&lhs(i+2*PacketSize, k)); + A = lhs.loadPacket(i+0*PacketSize, k); + B = lhs.loadPacket(i+1*PacketSize, k); + C = lhs.loadPacket(i+2*PacketSize, k); pstore(blockA+count, cj.pconj(A)); count+=PacketSize; pstore(blockA+count, cj.pconj(B)); count+=PacketSize; pstore(blockA+count, cj.pconj(C)); count+=PacketSize; @@ -1515,12 +1557,12 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, ColMajor, Conj for(; i<peeled_mc2; i+=2*PacketSize) { if(PanelMode) count += (2*PacketSize) * offset; - + for(Index k=0; k<depth; k++) { Packet A, B; - A = ploadu<Packet>(&lhs(i+0*PacketSize, k)); - B = ploadu<Packet>(&lhs(i+1*PacketSize, k)); + A = lhs.loadPacket(i+0*PacketSize, k); + B = lhs.loadPacket(i+1*PacketSize, k); pstore(blockA+count, cj.pconj(A)); count+=PacketSize; pstore(blockA+count, cj.pconj(B)); count+=PacketSize; } @@ -1533,11 +1575,11 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, ColMajor, Conj for(; i<peeled_mc1; i+=1*PacketSize) { if(PanelMode) count += (1*PacketSize) * offset; - + for(Index k=0; k<depth; k++) { Packet A; - A = ploadu<Packet>(&lhs(i+0*PacketSize, k)); + A = lhs.loadPacket(i+0*PacketSize, k); pstore(blockA+count, cj.pconj(A)); count+=PacketSize; } @@ -1550,11 +1592,11 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, ColMajor, Conj for(; i<peeled_mc0; i+=Pack2) { if(PanelMode) count += Pack2 * offset; - + for(Index k=0; k<depth; k++) for(Index w=0; w<Pack2; w++) blockA[count++] = cj(lhs(i+w, k)); - + if(PanelMode) count += Pack2 * (stride-offset-depth); } } @@ -1567,15 +1609,16 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, ColMajor, Conj } } -template<typename Scalar, typename Index, int Pack1, int Pack2, bool Conjugate, bool PanelMode> -struct gemm_pack_lhs<Scalar, Index, Pack1, Pack2, RowMajor, Conjugate, PanelMode> +template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, bool Conjugate, bool PanelMode> +struct gemm_pack_lhs<Scalar, Index, DataMapper, Pack1, Pack2, RowMajor, Conjugate, PanelMode> { - EIGEN_DONT_INLINE void operator()(Scalar* blockA, const Scalar* EIGEN_RESTRICT _lhs, Index lhsStride, Index depth, Index rows, Index stride=0, Index offset=0); + typedef typename DataMapper::LinearMapper LinearMapper; + EIGEN_DONT_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); }; -template<typename Scalar, typename Index, int Pack1, int Pack2, bool Conjugate, bool PanelMode> -EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, RowMajor, Conjugate, PanelMode> - ::operator()(Scalar* blockA, const Scalar* EIGEN_RESTRICT _lhs, Index lhsStride, Index depth, Index rows, Index stride, Index offset) +template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, bool Conjugate, bool PanelMode> +EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, DataMapper, Pack1, Pack2, RowMajor, Conjugate, PanelMode> + ::operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) { typedef typename packet_traits<Scalar>::type Packet; enum { PacketSize = packet_traits<Scalar>::size }; @@ -1585,13 +1628,12 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, RowMajor, Conj EIGEN_UNUSED_VARIABLE(offset); eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride)); conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj; - const_blas_data_mapper<Scalar, Index, RowMajor> lhs(_lhs,lhsStride); Index count = 0; - + // const Index peeled_mc3 = Pack1>=3*PacketSize ? (rows/(3*PacketSize))*(3*PacketSize) : 0; // const Index peeled_mc2 = Pack1>=2*PacketSize ? peeled_mc3+((rows-peeled_mc3)/(2*PacketSize))*(2*PacketSize) : 0; // const Index peeled_mc1 = Pack1>=1*PacketSize ? (rows/(1*PacketSize))*(1*PacketSize) : 0; - + int pack = Pack1; Index i = 0; while(pack>0) @@ -1611,7 +1653,7 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, RowMajor, Conj for (Index m = 0; m < pack; m += PacketSize) { PacketBlock<Packet> kernel; - for (int p = 0; p < PacketSize; ++p) kernel.packet[p] = ploadu<Packet>(&lhs(i+p+m, k)); + for (int p = 0; p < PacketSize; ++p) kernel.packet[p] = lhs.loadPacket(i+p+m, k); ptranspose(kernel); for (int p = 0; p < PacketSize; ++p) pstore(blockA+count+m+(pack)*p, cj.pconj(kernel.packet[p])); } @@ -1636,15 +1678,15 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, RowMajor, Conj for(;w<pack;++w) blockA[count++] = cj(lhs(i+w, k)); } - + if(PanelMode) count += pack * (stride-offset-depth); } - + pack -= PacketSize; if(pack<Pack2 && (pack+PacketSize)!=Pack2) pack = Pack2; } - + for(; i<rows; i++) { if(PanelMode) count += offset; @@ -1661,17 +1703,18 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, RowMajor, Conj // 4 5 6 7 16 17 18 19 25 28 // 8 9 10 11 20 21 22 23 26 29 // . . . . . . . . . . -template<typename Scalar, typename Index, int nr, bool Conjugate, bool PanelMode> -struct gemm_pack_rhs<Scalar, Index, nr, ColMajor, Conjugate, PanelMode> +template<typename Scalar, typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> +struct gemm_pack_rhs<Scalar, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> { typedef typename packet_traits<Scalar>::type Packet; + typedef typename DataMapper::LinearMapper LinearMapper; enum { PacketSize = packet_traits<Scalar>::size }; - EIGEN_DONT_INLINE void operator()(Scalar* blockB, const Scalar* rhs, Index rhsStride, Index depth, Index cols, Index stride=0, Index offset=0); + EIGEN_DONT_INLINE void operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); }; -template<typename Scalar, typename Index, int nr, bool Conjugate, bool PanelMode> -EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, nr, ColMajor, Conjugate, PanelMode> - ::operator()(Scalar* blockB, const Scalar* rhs, Index rhsStride, Index depth, Index cols, Index stride, Index offset) +template<typename Scalar, typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> +EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> + ::operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) { EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS COLMAJOR"); EIGEN_UNUSED_VARIABLE(stride); @@ -1727,27 +1770,27 @@ EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, nr, ColMajor, Conjugate, Pan // if(PanelMode) count += 8 * (stride-offset-depth); // } // } - + if(nr>=4) { for(Index j2=packet_cols8; j2<packet_cols4; j2+=4) { // skip what we have before if(PanelMode) count += 4 * offset; - const Scalar* b0 = &rhs[(j2+0)*rhsStride]; - const Scalar* b1 = &rhs[(j2+1)*rhsStride]; - const Scalar* b2 = &rhs[(j2+2)*rhsStride]; - const Scalar* b3 = &rhs[(j2+3)*rhsStride]; - + const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0); + const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1); + const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2); + const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3); + Index k=0; if((PacketSize%4)==0) // TODO enbale vectorized transposition for PacketSize==2 ?? { for(; k<peeled_k; k+=PacketSize) { PacketBlock<Packet,(PacketSize%4)==0?4:PacketSize> kernel; - kernel.packet[0] = ploadu<Packet>(&b0[k]); - kernel.packet[1] = ploadu<Packet>(&b1[k]); - kernel.packet[2] = ploadu<Packet>(&b2[k]); - kernel.packet[3] = ploadu<Packet>(&b3[k]); + kernel.packet[0] = dm0.loadPacket(k); + kernel.packet[1] = dm1.loadPacket(k); + kernel.packet[2] = dm2.loadPacket(k); + kernel.packet[3] = dm3.loadPacket(k); ptranspose(kernel); pstoreu(blockB+count+0*PacketSize, cj.pconj(kernel.packet[0])); pstoreu(blockB+count+1*PacketSize, cj.pconj(kernel.packet[1])); @@ -1758,10 +1801,10 @@ EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, nr, ColMajor, Conjugate, Pan } for(; k<depth; k++) { - blockB[count+0] = cj(b0[k]); - blockB[count+1] = cj(b1[k]); - blockB[count+2] = cj(b2[k]); - blockB[count+3] = cj(b3[k]); + blockB[count+0] = cj(dm0(k)); + blockB[count+1] = cj(dm1(k)); + blockB[count+2] = cj(dm2(k)); + blockB[count+3] = cj(dm3(k)); count += 4; } // skip what we have after @@ -1773,10 +1816,10 @@ EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, nr, ColMajor, Conjugate, Pan for(Index j2=packet_cols4; j2<cols; ++j2) { if(PanelMode) count += offset; - const Scalar* b0 = &rhs[(j2+0)*rhsStride]; + const LinearMapper dm0 = rhs.getLinearMapper(0, j2); for(Index k=0; k<depth; k++) { - blockB[count] = cj(b0[k]); + blockB[count] = cj(dm0(k)); count += 1; } if(PanelMode) count += (stride-offset-depth); @@ -1784,17 +1827,18 @@ EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, nr, ColMajor, Conjugate, Pan } // this version is optimized for row major matrices -template<typename Scalar, typename Index, int nr, bool Conjugate, bool PanelMode> -struct gemm_pack_rhs<Scalar, Index, nr, RowMajor, Conjugate, PanelMode> +template<typename Scalar, typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> +struct gemm_pack_rhs<Scalar, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> { typedef typename packet_traits<Scalar>::type Packet; + typedef typename DataMapper::LinearMapper LinearMapper; enum { PacketSize = packet_traits<Scalar>::size }; - EIGEN_DONT_INLINE void operator()(Scalar* blockB, const Scalar* rhs, Index rhsStride, Index depth, Index cols, Index stride=0, Index offset=0); + EIGEN_DONT_INLINE void operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); }; -template<typename Scalar, typename Index, int nr, bool Conjugate, bool PanelMode> -EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, nr, RowMajor, Conjugate, PanelMode> - ::operator()(Scalar* blockB, const Scalar* rhs, Index rhsStride, Index depth, Index cols, Index stride, Index offset) +template<typename Scalar, typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> +EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> + ::operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) { EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS ROWMAJOR"); EIGEN_UNUSED_VARIABLE(stride); @@ -1804,7 +1848,7 @@ EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, nr, RowMajor, Conjugate, Pan Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0; Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0; Index count = 0; - + // if(nr>=8) // { // for(Index j2=0; j2<packet_cols8; j2+=8) @@ -1847,15 +1891,15 @@ EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, nr, RowMajor, Conjugate, Pan for(Index k=0; k<depth; k++) { if (PacketSize==4) { - Packet A = ploadu<Packet>(&rhs[k*rhsStride + j2]); + Packet A = rhs.loadPacket(k, j2); pstoreu(blockB+count, cj.pconj(A)); count += PacketSize; } else { - const Scalar* b0 = &rhs[k*rhsStride + j2]; - blockB[count+0] = cj(b0[0]); - blockB[count+1] = cj(b0[1]); - blockB[count+2] = cj(b0[2]); - blockB[count+3] = cj(b0[3]); + const LinearMapper dm0 = rhs.getLinearMapper(k, j2); + blockB[count+0] = cj(dm0(0)); + blockB[count+1] = cj(dm0(1)); + blockB[count+2] = cj(dm0(2)); + blockB[count+3] = cj(dm0(3)); count += 4; } } @@ -1867,10 +1911,9 @@ EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, nr, RowMajor, Conjugate, Pan for(Index j2=packet_cols4; j2<cols; ++j2) { if(PanelMode) count += offset; - const Scalar* b0 = &rhs[j2]; for(Index k=0; k<depth; k++) { - blockB[count] = cj(b0[k*rhsStride]); + blockB[count] = cj(rhs(k, j2)); count += 1; } if(PanelMode) count += stride-offset-depth; @@ -1883,8 +1926,8 @@ EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, nr, RowMajor, Conjugate, Pan * \sa setCpuCacheSize */ inline std::ptrdiff_t l1CacheSize() { - std::ptrdiff_t l1, l2; - internal::manage_caching_sizes(GetAction, &l1, &l2); + std::ptrdiff_t l1, l2, l3; + internal::manage_caching_sizes(GetAction, &l1, &l2, &l3); return l1; } @@ -1892,8 +1935,8 @@ inline std::ptrdiff_t l1CacheSize() * \sa setCpuCacheSize */ inline std::ptrdiff_t l2CacheSize() { - std::ptrdiff_t l1, l2; - internal::manage_caching_sizes(GetAction, &l1, &l2); + std::ptrdiff_t l1, l2, l3; + internal::manage_caching_sizes(GetAction, &l1, &l2, &l3); return l2; } @@ -1902,9 +1945,9 @@ inline std::ptrdiff_t l2CacheSize() * for the algorithms working per blocks. * * \sa computeProductBlockingSizes */ -inline void setCpuCacheSizes(std::ptrdiff_t l1, std::ptrdiff_t l2) +inline void setCpuCacheSizes(std::ptrdiff_t l1, std::ptrdiff_t l2, std::ptrdiff_t l3) { - internal::manage_caching_sizes(SetAction, &l1, &l2); + internal::manage_caching_sizes(SetAction, &l1, &l2, &l3); } } // end namespace Eigen |