From e58827d2ed32cee5362e4d7d007da06a2bdc7309 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Mon, 25 Jan 2016 17:16:33 +0100 Subject: bug #51: make general_matrix_matrix_triangular_product use L3-blocking helper so that general symmetric rank-updates and general-matrix-to-triangular products do not trigger dynamic memory allocation for fixed size matrices. --- .../Core/products/GeneralMatrixMatrixTriangular.h | 44 +++++++++++++++++----- Eigen/src/Core/products/SelfadjointProduct.h | 24 +++++++++--- test/nomalloc.cpp | 9 +++-- 3 files changed, 57 insertions(+), 20 deletions(-) diff --git a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h index a36eb2fe0..f80f3b410 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h @@ -42,13 +42,14 @@ struct general_matrix_matrix_triangular_product::ReturnType ResScalar; static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* lhs, Index lhsStride, - const RhsScalar* rhs, Index rhsStride, ResScalar* res, Index resStride, const ResScalar& alpha) + const RhsScalar* rhs, Index rhsStride, ResScalar* res, Index resStride, + const ResScalar& alpha, level3_blocking& blocking) { general_matrix_matrix_triangular_product - ::run(size,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha); + ::run(size,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha,blocking); } }; @@ -58,7 +59,8 @@ struct general_matrix_matrix_triangular_product::ReturnType ResScalar; static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* _lhs, Index lhsStride, - const RhsScalar* _rhs, Index rhsStride, ResScalar* _res, Index resStride, const ResScalar& alpha) + const RhsScalar* _rhs, Index rhsStride, ResScalar* _res, Index resStride, + const ResScalar& alpha, level3_blocking& blocking) { typedef gebp_traits Traits; @@ -73,12 +75,20 @@ struct general_matrix_matrix_triangular_product(kc, mc, nc, 1); + + kc = blocking.kc(); + mc = (std::min)(size,blocking.mc()); + nc = (std::min)(size,blocking.nc()); + // !!! mc must be a multiple of nr: if(mc > Traits::nr) mc = (mc/Traits::nr)*Traits::nr; - ei_declare_aligned_stack_constructed_variable(LhsScalar, blockA, kc*mc, 0); - ei_declare_aligned_stack_constructed_variable(RhsScalar, blockB, kc*size, 0); + std::size_t sizeA = kc*mc; + std::size_t sizeB = kc*size; + + ei_declare_aligned_stack_constructed_variable(LhsScalar, blockA, sizeA, blocking.blockA()); + ei_declare_aligned_stack_constructed_variable(RhsScalar, blockB, sizeB, blocking.blockB()); gemm_pack_lhs pack_lhs; gemm_pack_rhs pack_rhs; @@ -256,13 +266,27 @@ struct general_product_to_triangular_selector typename ProductType::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived()); + enum { + IsRowMajor = (internal::traits::Flags&RowMajorBit) ? 1 : 0, + LhsIsRowMajor = _ActualLhs::Flags&RowMajorBit ? 1 : 0, + RhsIsRowMajor = _ActualRhs::Flags&RowMajorBit ? 1 : 0 + }; + + Index size = mat.cols(); + Index depth = actualLhs.cols(); + + typedef internal::gemm_blocking_space BlockingType; + + BlockingType blocking(size, size, depth, 1, false); + internal::general_matrix_matrix_triangular_product - ::run(mat.cols(), actualLhs.cols(), + typename Lhs::Scalar, LhsIsRowMajor ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate, + typename Rhs::Scalar, RhsIsRowMajor ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate, + IsRowMajor ? RowMajor : ColMajor, UpLo> + ::run(size, depth, &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &actualRhs.coeffRef(0,0), actualRhs.outerStride(), - mat.data(), mat.outerStride(), actualAlpha); + mat.data(), mat.outerStride(), actualAlpha, blocking); } }; diff --git a/Eigen/src/Core/products/SelfadjointProduct.h b/Eigen/src/Core/products/SelfadjointProduct.h index 2af00058d..f038d686f 100644 --- a/Eigen/src/Core/products/SelfadjointProduct.h +++ b/Eigen/src/Core/products/SelfadjointProduct.h @@ -92,15 +92,27 @@ struct selfadjoint_product_selector Scalar actualAlpha = alpha * OtherBlasTraits::extractScalarFactor(other.derived()); - enum { IsRowMajor = (internal::traits::Flags&RowMajorBit) ? 1 : 0 }; + enum { + IsRowMajor = (internal::traits::Flags&RowMajorBit) ? 1 : 0, + OtherIsRowMajor = _ActualOtherType::Flags&RowMajorBit ? 1 : 0 + }; + + Index size = mat.cols(); + Index depth = actualOther.cols(); + + typedef internal::gemm_blocking_space BlockingType; + + BlockingType blocking(size, size, depth, 1, false); + internal::general_matrix_matrix_triangular_product::IsComplex, - Scalar, _ActualOtherType::Flags&RowMajorBit ? ColMajor : RowMajor, (!OtherBlasTraits::NeedToConjugate) && NumTraits::IsComplex, - MatrixType::Flags&RowMajorBit ? RowMajor : ColMajor, UpLo> - ::run(mat.cols(), actualOther.cols(), + Scalar, OtherIsRowMajor ? RowMajor : ColMajor, OtherBlasTraits::NeedToConjugate && NumTraits::IsComplex, + Scalar, OtherIsRowMajor ? ColMajor : RowMajor, (!OtherBlasTraits::NeedToConjugate) && NumTraits::IsComplex, + IsRowMajor ? RowMajor : ColMajor, UpLo> + ::run(size, depth, &actualOther.coeffRef(0,0), actualOther.outerStride(), &actualOther.coeffRef(0,0), actualOther.outerStride(), - mat.data(), mat.outerStride(), actualAlpha); + mat.data(), mat.outerStride(), actualAlpha, blocking); } }; diff --git a/test/nomalloc.cpp b/test/nomalloc.cpp index 060276a20..d85e9e5bc 100644 --- a/test/nomalloc.cpp +++ b/test/nomalloc.cpp @@ -81,11 +81,12 @@ template void nomalloc(const MatrixType& m) m2.template selfadjointView().rankUpdate(m1.row(0),-1); // The following fancy matrix-matrix products are not safe yet regarding static allocation -// m1 += m1.template triangularView() * m2.col(; -// m1.template selfadjointView().rankUpdate(m2); -// m1 += m1.template triangularView() * m2; +// m1.col(1) += m1.template triangularView() * m2.col(0); + m2.template selfadjointView().rankUpdate(m1); + m2 += m2.template triangularView() * m1; + m2.template triangularView() = m2 * m2; // m1 += m1.template selfadjointView() * m2; -// VERIFY_IS_APPROX(m1,m1); + VERIFY_IS_APPROX(m2,m2); } template -- cgit v1.2.3