aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h
diff options
context:
space:
mode:
Diffstat (limited to 'Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h')
-rw-r--r--Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h48
1 files changed, 33 insertions, 15 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h
index a36eb2fe0..831089dee 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h
@@ -42,13 +42,14 @@ struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,
{
typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* lhs, Index lhsStride,
- const RhsScalar* rhs, Index rhsStride, ResScalar* res, Index resStride, const ResScalar& alpha)
+ const RhsScalar* rhs, Index rhsStride, ResScalar* res, Index resStride,
+ const ResScalar& alpha, level3_blocking<LhsScalar,RhsScalar>& blocking)
{
general_matrix_matrix_triangular_product<Index,
RhsScalar, RhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateRhs,
LhsScalar, LhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateLhs,
ColMajor, UpLo==Lower?Upper:Lower>
- ::run(size,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha);
+ ::run(size,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha,blocking);
}
};
@@ -58,7 +59,8 @@ struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,
{
typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* _lhs, Index lhsStride,
- const RhsScalar* _rhs, Index rhsStride, ResScalar* _res, Index resStride, const ResScalar& alpha)
+ const RhsScalar* _rhs, Index rhsStride, ResScalar* _res, Index resStride,
+ const ResScalar& alpha, level3_blocking<LhsScalar,RhsScalar>& blocking)
{
typedef gebp_traits<LhsScalar,RhsScalar> Traits;
@@ -69,16 +71,18 @@ struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,
RhsMapper rhs(_rhs,rhsStride);
ResMapper res(_res, resStride);
- Index kc = depth; // cache block size along the K direction
- Index mc = size; // cache block size along the M direction
- Index nc = size; // cache block size along the N direction
- computeProductBlockingSizes<LhsScalar,RhsScalar>(kc, mc, nc, 1);
+ Index kc = blocking.kc();
+ Index mc = (std::min)(size,blocking.mc());
+
// !!! mc must be a multiple of nr:
if(mc > Traits::nr)
mc = (mc/Traits::nr)*Traits::nr;
- ei_declare_aligned_stack_constructed_variable(LhsScalar, blockA, kc*mc, 0);
- ei_declare_aligned_stack_constructed_variable(RhsScalar, blockB, kc*size, 0);
+ std::size_t sizeA = kc*mc;
+ std::size_t sizeB = kc*size;
+
+ ei_declare_aligned_stack_constructed_variable(LhsScalar, blockA, sizeA, blocking.blockA());
+ ei_declare_aligned_stack_constructed_variable(RhsScalar, blockB, sizeB, blocking.blockB());
gemm_pack_lhs<LhsScalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
gemm_pack_rhs<RhsScalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
@@ -136,7 +140,7 @@ struct tribb_kernel
typedef typename Traits::ResScalar ResScalar;
enum {
- BlockSize = EIGEN_PLAIN_ENUM_MAX(mr,nr)
+ BlockSize = meta_least_common_multiple<EIGEN_PLAIN_ENUM_MAX(mr,nr),EIGEN_PLAIN_ENUM_MIN(mr,nr)>::ret
};
void operator()(ResScalar* _res, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB, Index size, Index depth, const ResScalar& alpha)
{
@@ -256,13 +260,27 @@ struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,false>
typename ProductType::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived());
+ enum {
+ IsRowMajor = (internal::traits<MatrixType>::Flags&RowMajorBit) ? 1 : 0,
+ LhsIsRowMajor = _ActualLhs::Flags&RowMajorBit ? 1 : 0,
+ RhsIsRowMajor = _ActualRhs::Flags&RowMajorBit ? 1 : 0
+ };
+
+ Index size = mat.cols();
+ Index depth = actualLhs.cols();
+
+ typedef internal::gemm_blocking_space<IsRowMajor ? RowMajor : ColMajor,typename Lhs::Scalar,typename Rhs::Scalar,
+ MatrixType::MaxColsAtCompileTime, MatrixType::MaxColsAtCompileTime, _ActualRhs::MaxColsAtCompileTime> BlockingType;
+
+ BlockingType blocking(size, size, depth, 1, false);
+
internal::general_matrix_matrix_triangular_product<Index,
- typename Lhs::Scalar, _ActualLhs::Flags&RowMajorBit ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate,
- typename Rhs::Scalar, _ActualRhs::Flags&RowMajorBit ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
- MatrixType::Flags&RowMajorBit ? RowMajor : ColMajor, UpLo>
- ::run(mat.cols(), actualLhs.cols(),
+ typename Lhs::Scalar, LhsIsRowMajor ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate,
+ typename Rhs::Scalar, RhsIsRowMajor ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
+ IsRowMajor ? RowMajor : ColMajor, UpLo>
+ ::run(size, depth,
&actualLhs.coeffRef(0,0), actualLhs.outerStride(), &actualRhs.coeffRef(0,0), actualRhs.outerStride(),
- mat.data(), mat.outerStride(), actualAlpha);
+ mat.data(), mat.outerStride(), actualAlpha, blocking);
}
};