diff options
Diffstat (limited to 'Eigen/src/Core/products/SelfadjointMatrixMatrix.h')
-rw-r--r-- | Eigen/src/Core/products/SelfadjointMatrixMatrix.h | 46 |
1 files changed, 23 insertions, 23 deletions
diff --git a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h index f84f54982..da6f82abc 100644 --- a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +++ b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h @@ -291,7 +291,7 @@ struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,LhsSelfAdjoint,Co const Scalar* lhs, Index lhsStride, const Scalar* rhs, Index rhsStride, Scalar* res, Index resStride, - const Scalar& alpha) + const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking) { product_selfadjoint_matrix<Scalar, Index, EIGEN_LOGICAL_XOR(RhsSelfAdjoint,RhsStorageOrder==RowMajor) ? ColMajor : RowMajor, @@ -299,7 +299,7 @@ struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,LhsSelfAdjoint,Co EIGEN_LOGICAL_XOR(LhsSelfAdjoint,LhsStorageOrder==RowMajor) ? ColMajor : RowMajor, LhsSelfAdjoint, NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(LhsSelfAdjoint,ConjugateLhs), ColMajor> - ::run(cols, rows, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha); + ::run(cols, rows, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha, blocking); } }; @@ -314,7 +314,7 @@ struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs const Scalar* _lhs, Index lhsStride, const Scalar* _rhs, Index rhsStride, Scalar* res, Index resStride, - const Scalar& alpha); + const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking); }; template <typename Scalar, typename Index, @@ -325,7 +325,7 @@ EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,t const Scalar* _lhs, Index lhsStride, const Scalar* _rhs, Index rhsStride, Scalar* _res, Index resStride, - const Scalar& alpha) + const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking) { Index size = rows; @@ -340,17 +340,14 @@ EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,t RhsMapper rhs(_rhs,rhsStride); ResMapper res(_res, resStride); - Index kc = size; // cache block size along the K direction - Index mc = rows; // cache block size along the M direction - Index nc = cols; // cache block size along the N direction - computeProductBlockingSizes<Scalar,Scalar>(kc, mc, nc, 1); - // kc must smaller than mc + Index kc = blocking.kc(); // cache block size along the K direction + Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction + // kc must be smaller than mc kc = (std::min)(kc,mc); - + std::size_t sizeA = kc*mc; std::size_t sizeB = kc*cols; - ei_declare_aligned_stack_constructed_variable(Scalar, blockA, kc*mc, 0); - ei_declare_aligned_stack_constructed_variable(Scalar, allocatedBlockB, sizeB, 0); - Scalar* blockB = allocatedBlockB; + ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA()); + ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB()); gebp_kernel<Scalar, Scalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel; symm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs; @@ -410,7 +407,7 @@ struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLh const Scalar* _lhs, Index lhsStride, const Scalar* _rhs, Index rhsStride, Scalar* res, Index resStride, - const Scalar& alpha); + const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking); }; template <typename Scalar, typename Index, @@ -421,7 +418,7 @@ EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,f const Scalar* _lhs, Index lhsStride, const Scalar* _rhs, Index rhsStride, Scalar* _res, Index resStride, - const Scalar& alpha) + const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking) { Index size = cols; @@ -432,14 +429,12 @@ EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,f LhsMapper lhs(_lhs,lhsStride); ResMapper res(_res,resStride); - Index kc = size; // cache block size along the K direction - Index mc = rows; // cache block size along the M direction - Index nc = cols; // cache block size along the N direction - computeProductBlockingSizes<Scalar,Scalar>(kc, mc, nc, 1); + Index kc = blocking.kc(); // cache block size along the K direction + Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction + std::size_t sizeA = kc*mc; std::size_t sizeB = kc*cols; - ei_declare_aligned_stack_constructed_variable(Scalar, blockA, kc*mc, 0); - ei_declare_aligned_stack_constructed_variable(Scalar, allocatedBlockB, sizeB, 0); - Scalar* blockB = allocatedBlockB; + ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA()); + ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB()); gebp_kernel<Scalar, Scalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel; gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs; @@ -498,6 +493,11 @@ struct selfadjoint_product_impl<Lhs,LhsMode,false,Rhs,RhsMode,false> Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs) * RhsBlasTraits::extractScalarFactor(a_rhs); + typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,Scalar,Scalar, + Lhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime, Lhs::MaxColsAtCompileTime,1> BlockingType; + + BlockingType blocking(lhs.rows(), rhs.cols(), lhs.cols(), 1, false); + internal::product_selfadjoint_matrix<Scalar, Index, EIGEN_LOGICAL_XOR(LhsIsUpper,internal::traits<Lhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, LhsIsSelfAdjoint, NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(LhsIsUpper,bool(LhsBlasTraits::NeedToConjugate)), @@ -509,7 +509,7 @@ struct selfadjoint_product_impl<Lhs,LhsMode,false,Rhs,RhsMode,false> &lhs.coeffRef(0,0), lhs.outerStride(), // lhs info &rhs.coeffRef(0,0), rhs.outerStride(), // rhs info &dst.coeffRef(0,0), dst.outerStride(), // result info - actualAlpha // alpha + actualAlpha, blocking // alpha ); } }; |