diff options
author | Gael Guennebaud <g.gael@free.fr> | 2010-06-24 21:44:24 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2010-06-24 21:44:24 +0200 |
commit | 566867428c6134381c846d98ab86f9d2cc89a3c2 (patch) | |
tree | 89c8807654b80eb9a27d3e9286265a8b88377054 /Eigen/src | |
parent | e039edcb422e3b5c6c0c06e1a5ba69a22695ebe8 (diff) |
- add a low level mechanism to provide preallocated memory to gemm
- ensure static allocation for the product of "large" fixed size matrix
Diffstat (limited to 'Eigen/src')
-rw-r--r-- | Eigen/src/Core/products/GeneralMatrixMatrix.h | 220 | ||||
-rw-r--r-- | Eigen/src/Core/products/Parallelizer.h | 13 |
2 files changed, 185 insertions, 48 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index 062d75ba9..e614232f7 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -25,6 +25,8 @@ #ifndef EIGEN_GENERAL_MATRIX_MATRIX_H #define EIGEN_GENERAL_MATRIX_MATRIX_H +template<typename _LhsScalar, typename _RhsScalar> class ei_level3_blocking; + /* Specialization for a row-major destination matrix => simple transposition of the product */ template< typename Scalar, typename Index, @@ -38,7 +40,8 @@ struct ei_general_matrix_matrix_product<Scalar,Index,LhsStorageOrder,ConjugateLh const Scalar* rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha, - GemmParallelInfo<Scalar, Index>* info = 0) + ei_level3_blocking<Scalar,Scalar>& blocking, + GemmParallelInfo<Index>* info = 0) { // transpose the product such that the result is column major ei_general_matrix_matrix_product<Scalar, Index, @@ -47,7 +50,7 @@ struct ei_general_matrix_matrix_product<Scalar,Index,LhsStorageOrder,ConjugateLh LhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateLhs, ColMajor> - ::run(cols,rows,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha,info); + ::run(cols,rows,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha,blocking,info); } }; @@ -64,7 +67,8 @@ static void run(Index rows, Index cols, Index depth, const Scalar* _rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha, - GemmParallelInfo<Scalar,Index>* info = 0) + ei_level3_blocking<Scalar,Scalar>& blocking, + GemmParallelInfo<Index>* info = 0) { ei_const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride); ei_const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride); @@ -75,10 +79,9 @@ static void run(Index rows, Index cols, Index depth, typedef typename ei_packet_traits<Scalar>::type PacketType; typedef ei_product_blocking_traits<Scalar> Blocking; - Index kc = depth; // cache block size along the K direction - Index mc = rows; // cache block size along the M direction - Index nc = cols; // cache block size along the N direction - computeProductBlockingSizes<Scalar,Scalar>(kc, mc, nc); + Index kc = blocking.kc(); // cache block size along the K direction + Index mc = std::min(rows,blocking.mc()); // cache block size along the M direction + //Index nc = blocking.nc(); // cache block size along the N direction ei_gemm_pack_rhs<Scalar, Index, Blocking::nr, RhsStorageOrder> pack_rhs; ei_gemm_pack_lhs<Scalar, Index, Blocking::mr, LhsStorageOrder> pack_lhs; @@ -94,10 +97,10 @@ static void run(Index rows, Index cols, Index depth, Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc); std::size_t sizeW = kc*Blocking::PacketSize*Blocking::nr*8; Scalar* w = ei_aligned_stack_new(Scalar, sizeW); - Scalar* blockB = (Scalar*)info[tid].blockB; + Scalar* blockB = blocking.blockB(); + ei_internal_assert(blockB!=0); - // For each horizontal panel of the rhs, and corresponding panel of the lhs... - // (==GEMM_VAR1) + // For each horizontal panel of the rhs, and corresponding vertical panel of the lhs... for(Index k=0; k<depth; k+=kc) { const Index actual_kc = std::min(k+kc,depth)-k; // => rows of B', and cols of the A' @@ -106,7 +109,7 @@ static void run(Index rows, Index cols, Index depth, // let's start by packing A'. pack_lhs(blockA, &lhs(0,k), lhsStride, actual_kc, mc); - // Pack B_k to B' in parallel fashion: + // Pack B_k to B' in a parallel fashion: // each thread packs the sub block B_k,j to B'_j where j is the thread id. // However, before copying to B'_j, we have to make sure that no other thread is still using it, @@ -162,10 +165,12 @@ static void run(Index rows, Index cols, Index depth, EIGEN_UNUSED_VARIABLE(info); // this is the sequential version! - Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc); - std::size_t sizeB = kc*Blocking::PacketSize*Blocking::nr + kc*cols; - Scalar* allocatedBlockB = ei_aligned_stack_new(Scalar, sizeB); - Scalar* blockB = allocatedBlockB + kc*Blocking::PacketSize*Blocking::nr; + std::size_t sizeA = kc*mc; + std::size_t sizeB = kc*cols; + std::size_t sizeW = kc*Blocking::PacketSize*Blocking::nr; + Scalar *blockA = blocking.blockA()==0 ? ei_aligned_stack_new(Scalar, sizeA) : blocking.blockA(); + Scalar *blockB = blocking.blockB()==0 ? ei_aligned_stack_new(Scalar, sizeB) : blocking.blockB(); + Scalar *blockW = blocking.blockW()==0 ? ei_aligned_stack_new(Scalar, sizeW) : blocking.blockW(); // For each horizontal panel of the rhs, and corresponding panel of the lhs... // (==GEMM_VAR1) @@ -192,13 +197,14 @@ static void run(Index rows, Index cols, Index depth, pack_lhs(blockA, &lhs(i2,k2), lhsStride, actual_kc, actual_mc); // Everything is packed, we can now call the block * panel kernel: - gebp(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols); + gebp(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols, -1, -1, 0, 0, blockW); } } - ei_aligned_stack_delete(Scalar, blockA, kc*mc); - ei_aligned_stack_delete(Scalar, allocatedBlockB, sizeB); + if(blocking.blockA()==0) ei_aligned_stack_delete(Scalar, blockA, kc*mc); + if(blocking.blockB()==0) ei_aligned_stack_delete(Scalar, blockB, sizeB); + if(blocking.blockW()==0) ei_aligned_stack_delete(Scalar, blockW, sizeW); } } @@ -214,33 +220,25 @@ struct ei_traits<GeneralProduct<Lhs,Rhs,GemmProduct> > : ei_traits<ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs> > {}; -template<typename Scalar, typename Index, typename Gemm, typename Lhs, typename Rhs, typename Dest> +template<typename Scalar, typename Index, typename Gemm, typename Lhs, typename Rhs, typename Dest, typename BlockingType> struct ei_gemm_functor { - typedef typename Rhs::Scalar BlockBScalar; - - ei_gemm_functor(const Lhs& lhs, const Rhs& rhs, Dest& dest, Scalar actualAlpha) - : m_lhs(lhs), m_rhs(rhs), m_dest(dest), m_actualAlpha(actualAlpha) + ei_gemm_functor(const Lhs& lhs, const Rhs& rhs, Dest& dest, Scalar actualAlpha, + BlockingType& blocking) + : m_lhs(lhs), m_rhs(rhs), m_dest(dest), m_actualAlpha(actualAlpha), m_blocking(blocking) {} - void operator() (Index row, Index rows, Index col=0, Index cols=-1, GemmParallelInfo<BlockBScalar,Index>* info=0) const + void operator() (Index row, Index rows, Index col=0, Index cols=-1, GemmParallelInfo<Index>* info=0) const { if(cols==-1) cols = m_rhs.cols(); + if(info) + m_blocking.allocateB(); Gemm::run(rows, cols, m_lhs.cols(), (const Scalar*)&(m_lhs.const_cast_derived().coeffRef(row,0)), m_lhs.outerStride(), (const Scalar*)&(m_rhs.const_cast_derived().coeffRef(0,col)), m_rhs.outerStride(), (Scalar*)&(m_dest.coeffRef(row,col)), m_dest.outerStride(), - m_actualAlpha, - info); - } - - - Index sharedBlockBSize() const - { - Index kc = m_rhs.rows(), mc = m_lhs.rows(), nc = m_rhs.cols(); - computeProductBlockingSizes<Scalar,Scalar>(kc, mc, nc); - return kc * nc;; + m_actualAlpha, m_blocking, info); } protected: @@ -248,12 +246,155 @@ struct ei_gemm_functor const Rhs& m_rhs; Dest& m_dest; Scalar m_actualAlpha; + BlockingType& m_blocking; +}; + +template<int StorageOrder, typename LhsScalar, typename RhsScalar, int MaxRows, int MaxCols, int MaxDepth, +bool FiniteAtCompileTime = MaxRows!=Dynamic && MaxCols!=Dynamic && MaxDepth != Dynamic> struct ei_gemm_blocking_space; + +template<typename _LhsScalar, typename _RhsScalar> +class ei_level3_blocking +{ + typedef _LhsScalar LhsScalar; + typedef _RhsScalar RhsScalar; + + protected: + LhsScalar* m_blockA; + RhsScalar* m_blockB; + RhsScalar* m_blockW; + + DenseIndex m_mc; + DenseIndex m_nc; + DenseIndex m_kc; + + public: + + ei_level3_blocking() + : m_blockA(0), m_blockB(0), m_blockW(0), m_mc(0), m_nc(0), m_kc(0) + {} + + inline DenseIndex mc() const { return m_mc; } + inline DenseIndex nc() const { return m_nc; } + inline DenseIndex kc() const { return m_kc; } + + inline LhsScalar* blockA() { return m_blockA; } + inline RhsScalar* blockB() { return m_blockB; } + inline RhsScalar* blockW() { return m_blockW; } +}; + +template<int StorageOrder, typename _LhsScalar, typename _RhsScalar, int MaxRows, int MaxCols, int MaxDepth> +class ei_gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, MaxDepth, true> + : public ei_level3_blocking< + typename ei_meta_if<StorageOrder==RowMajor,_RhsScalar,_LhsScalar>::ret, + typename ei_meta_if<StorageOrder==RowMajor,_LhsScalar,_RhsScalar>::ret> +{ + enum { + Transpose = StorageOrder==RowMajor, + ActualRows = Transpose ? MaxCols : MaxRows, + ActualCols = Transpose ? MaxRows : MaxCols + }; + typedef typename ei_meta_if<Transpose,_RhsScalar,_LhsScalar>::ret LhsScalar; + typedef typename ei_meta_if<Transpose,_LhsScalar,_RhsScalar>::ret RhsScalar; + typedef ei_product_blocking_traits<RhsScalar> Blocking; + enum { + SizeA = ActualCols * MaxDepth, + SizeB = ActualRows * MaxDepth, + SizeW = MaxDepth * Blocking::nr * ei_packet_traits<RhsScalar>::size + }; + + EIGEN_ALIGN16 LhsScalar m_staticA[SizeA]; + EIGEN_ALIGN16 RhsScalar m_staticB[SizeB]; + EIGEN_ALIGN16 RhsScalar m_staticW[SizeW]; + + public: + + ei_gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth) + { + this->m_mc = ActualRows; + this->m_nc = ActualCols; + this->m_kc = MaxDepth; + this->m_blockA = m_staticA; + this->m_blockB = m_staticB; + this->m_blockW = m_staticW; + } + + inline void allocateA() {} + inline void allocateB() {} + inline void allocateW() {} + inline void allocateAll() {} +}; + +template<int StorageOrder, typename _LhsScalar, typename _RhsScalar, int MaxRows, int MaxCols, int MaxDepth> +struct ei_gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, MaxDepth, false> + : public ei_level3_blocking< + typename ei_meta_if<StorageOrder==RowMajor,_RhsScalar,_LhsScalar>::ret, + typename ei_meta_if<StorageOrder==RowMajor,_LhsScalar,_RhsScalar>::ret> +{ + enum { + Transpose = StorageOrder==RowMajor + }; + typedef typename ei_meta_if<Transpose,_RhsScalar,_LhsScalar>::ret LhsScalar; + typedef typename ei_meta_if<Transpose,_LhsScalar,_RhsScalar>::ret RhsScalar; + typedef ei_product_blocking_traits<RhsScalar> Blocking; + + DenseIndex m_sizeA; + DenseIndex m_sizeB; + DenseIndex m_sizeW; + + public: + + ei_gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth) + { + this->m_mc = Transpose ? cols : rows; + this->m_nc = Transpose ? rows : cols; + this->m_kc = depth; + + computeProductBlockingSizes<LhsScalar,RhsScalar>(this->m_kc, this->m_mc, this->m_nc); + m_sizeA = this->m_mc * this->m_kc; + m_sizeB = this->m_kc * this->m_nc; + m_sizeW = this->m_kc*ei_packet_traits<RhsScalar>::size*Blocking::nr; + } + + void allocateA() + { + if(this->m_blockA==0) + this->m_blockA = ei_aligned_new<LhsScalar>(m_sizeA); + } + + void allocateB() + { + if(this->m_blockB==0) + this->m_blockB = ei_aligned_new<RhsScalar>(m_sizeB); + } + + void allocateW() + { + if(this->m_blockB==0) + this->m_blockB = ei_aligned_new<RhsScalar>(m_sizeB); + } + + void allocateAll() + { + allocateA(); + allocateB(); + allocateW(); + } + + ~ei_gemm_blocking_space() + { + ei_aligned_delete(this->m_blockA, m_sizeA); + ei_aligned_delete(this->m_blockB, m_sizeB); + ei_aligned_delete(this->m_blockW, m_sizeW); + } }; template<typename Lhs, typename Rhs> class GeneralProduct<Lhs, Rhs, GemmProduct> : public ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs> { + enum { + MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(Lhs::MaxColsAtCompileTime,Rhs::MaxRowsAtCompileTime) + }; public: EIGEN_PRODUCT_PUBLIC_INTERFACE(GeneralProduct) @@ -273,6 +414,9 @@ class GeneralProduct<Lhs, Rhs, GemmProduct> Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs) * RhsBlasTraits::extractScalarFactor(m_rhs); + typedef ei_gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,Scalar,Scalar, + Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType; + typedef ei_gemm_functor< Scalar, Index, ei_general_matrix_matrix_product< @@ -280,11 +424,11 @@ class GeneralProduct<Lhs, Rhs, GemmProduct> (_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate), (_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate), (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>, - _ActualLhsType, - _ActualRhsType, - Dest> GemmFunctor; + _ActualLhsType, _ActualRhsType, Dest, BlockingType> GemmFunctor; + + BlockingType blocking(dst.rows(), dst.cols(), lhs.cols()); - ei_parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>(GemmFunctor(lhs, rhs, dst, actualAlpha), this->rows(), this->cols()); + ei_parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>(GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), this->rows(), this->cols()); } }; diff --git a/Eigen/src/Core/products/Parallelizer.h b/Eigen/src/Core/products/Parallelizer.h index 750fa7b5f..c51851121 100644 --- a/Eigen/src/Core/products/Parallelizer.h +++ b/Eigen/src/Core/products/Parallelizer.h @@ -69,16 +69,15 @@ inline void setNbThreads(int v) ei_manage_multi_threading(SetAction, &v); } -template<typename BlockBScalar, typename Index> struct GemmParallelInfo +template<typename Index> struct GemmParallelInfo { - GemmParallelInfo() : sync(-1), users(0), rhs_start(0), rhs_length(0), blockB(0) {} + GemmParallelInfo() : sync(-1), users(0), rhs_start(0), rhs_length(0) {} int volatile sync; int volatile users; Index rhs_start; Index rhs_length; - BlockBScalar* blockB; }; template<bool Condition, typename Functor, typename Index> @@ -112,11 +111,7 @@ void ei_parallelize_gemm(const Functor& func, Index rows, Index cols) Index blockCols = (cols / threads) & ~Index(0x3); Index blockRows = (rows / threads) & ~Index(0x7); - typedef typename Functor::BlockBScalar BlockBScalar; - BlockBScalar* sharedBlockB = new BlockBScalar[func.sharedBlockBSize()]; - - GemmParallelInfo<BlockBScalar,Index>* info = new - GemmParallelInfo<BlockBScalar,Index>[threads]; + GemmParallelInfo<Index>* info = new GemmParallelInfo<Index>[threads]; #pragma omp parallel for schedule(static,1) num_threads(threads) for(Index i=0; i<threads; ++i) @@ -129,12 +124,10 @@ void ei_parallelize_gemm(const Functor& func, Index rows, Index cols) info[i].rhs_start = c0; info[i].rhs_length = actualBlockCols; - info[i].blockB = sharedBlockB; func(r0, actualBlockRows, 0,cols, info); } - delete[] sharedBlockB; delete[] info; #endif } |