diff options
Diffstat (limited to 'Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h')
-rw-r--r-- | Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h | 48 |
1 files changed, 33 insertions, 15 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h index a36eb2fe0..831089dee 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h @@ -42,13 +42,14 @@ struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder, { typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar; static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* lhs, Index lhsStride, - const RhsScalar* rhs, Index rhsStride, ResScalar* res, Index resStride, const ResScalar& alpha) + const RhsScalar* rhs, Index rhsStride, ResScalar* res, Index resStride, + const ResScalar& alpha, level3_blocking<LhsScalar,RhsScalar>& blocking) { general_matrix_matrix_triangular_product<Index, RhsScalar, RhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateRhs, LhsScalar, LhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateLhs, ColMajor, UpLo==Lower?Upper:Lower> - ::run(size,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha); + ::run(size,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha,blocking); } }; @@ -58,7 +59,8 @@ struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder, { typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar; static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* _lhs, Index lhsStride, - const RhsScalar* _rhs, Index rhsStride, ResScalar* _res, Index resStride, const ResScalar& alpha) + const RhsScalar* _rhs, Index rhsStride, ResScalar* _res, Index resStride, + const ResScalar& alpha, level3_blocking<LhsScalar,RhsScalar>& blocking) { typedef gebp_traits<LhsScalar,RhsScalar> Traits; @@ -69,16 +71,18 @@ struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder, RhsMapper rhs(_rhs,rhsStride); ResMapper res(_res, resStride); - Index kc = depth; // cache block size along the K direction - Index mc = size; // cache block size along the M direction - Index nc = size; // cache block size along the N direction - computeProductBlockingSizes<LhsScalar,RhsScalar>(kc, mc, nc, 1); + Index kc = blocking.kc(); + Index mc = (std::min)(size,blocking.mc()); + // !!! mc must be a multiple of nr: if(mc > Traits::nr) mc = (mc/Traits::nr)*Traits::nr; - ei_declare_aligned_stack_constructed_variable(LhsScalar, blockA, kc*mc, 0); - ei_declare_aligned_stack_constructed_variable(RhsScalar, blockB, kc*size, 0); + std::size_t sizeA = kc*mc; + std::size_t sizeB = kc*size; + + ei_declare_aligned_stack_constructed_variable(LhsScalar, blockA, sizeA, blocking.blockA()); + ei_declare_aligned_stack_constructed_variable(RhsScalar, blockB, sizeB, blocking.blockB()); gemm_pack_lhs<LhsScalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs; gemm_pack_rhs<RhsScalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs; @@ -136,7 +140,7 @@ struct tribb_kernel typedef typename Traits::ResScalar ResScalar; enum { - BlockSize = EIGEN_PLAIN_ENUM_MAX(mr,nr) + BlockSize = meta_least_common_multiple<EIGEN_PLAIN_ENUM_MAX(mr,nr),EIGEN_PLAIN_ENUM_MIN(mr,nr)>::ret }; void operator()(ResScalar* _res, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB, Index size, Index depth, const ResScalar& alpha) { @@ -256,13 +260,27 @@ struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,false> typename ProductType::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived()); + enum { + IsRowMajor = (internal::traits<MatrixType>::Flags&RowMajorBit) ? 1 : 0, + LhsIsRowMajor = _ActualLhs::Flags&RowMajorBit ? 1 : 0, + RhsIsRowMajor = _ActualRhs::Flags&RowMajorBit ? 1 : 0 + }; + + Index size = mat.cols(); + Index depth = actualLhs.cols(); + + typedef internal::gemm_blocking_space<IsRowMajor ? RowMajor : ColMajor,typename Lhs::Scalar,typename Rhs::Scalar, + MatrixType::MaxColsAtCompileTime, MatrixType::MaxColsAtCompileTime, _ActualRhs::MaxColsAtCompileTime> BlockingType; + + BlockingType blocking(size, size, depth, 1, false); + internal::general_matrix_matrix_triangular_product<Index, - typename Lhs::Scalar, _ActualLhs::Flags&RowMajorBit ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate, - typename Rhs::Scalar, _ActualRhs::Flags&RowMajorBit ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate, - MatrixType::Flags&RowMajorBit ? RowMajor : ColMajor, UpLo> - ::run(mat.cols(), actualLhs.cols(), + typename Lhs::Scalar, LhsIsRowMajor ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate, + typename Rhs::Scalar, RhsIsRowMajor ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate, + IsRowMajor ? RowMajor : ColMajor, UpLo> + ::run(size, depth, &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &actualRhs.coeffRef(0,0), actualRhs.outerStride(), - mat.data(), mat.outerStride(), actualAlpha); + mat.data(), mat.outerStride(), actualAlpha, blocking); } }; |