diff options
Diffstat (limited to 'Eigen/src/Core/products/GeneralMatrixMatrix.h')
-rw-r--r-- | Eigen/src/Core/products/GeneralMatrixMatrix.h | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index 4d7a6270a..49362adbe 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -299,7 +299,7 @@ class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, M public: - gemm_blocking_space(DenseIndex /*rows*/, DenseIndex /*cols*/, DenseIndex /*depth*/, bool /*full_rows*/ = false) + gemm_blocking_space(DenseIndex /*rows*/, DenseIndex /*cols*/, DenseIndex /*depth*/, int /*num_threads*/, bool /*full_rows = false*/) { this->m_mc = ActualRows; this->m_nc = ActualCols; @@ -331,21 +331,21 @@ class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, M public: - gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth, bool full_rows = false) + gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth, int num_threads, bool l3_blocking) { this->m_mc = Transpose ? cols : rows; this->m_nc = Transpose ? rows : cols; this->m_kc = depth; - if(full_rows) + if(l3_blocking) { - DenseIndex m = this->m_mc; - computeProductBlockingSizes<LhsScalar,RhsScalar,KcFactor>(this->m_kc, m, this->m_nc); + computeProductBlockingSizes<LhsScalar,RhsScalar,KcFactor>(this->m_kc, this->m_mc, this->m_nc, num_threads); } - else // full columns + else // no l3 blocking { + DenseIndex m = this->m_mc; DenseIndex n = this->m_nc; - computeProductBlockingSizes<LhsScalar,RhsScalar,KcFactor>(this->m_kc, this->m_mc, n); + computeProductBlockingSizes<LhsScalar,RhsScalar,KcFactor>(this->m_kc, m, n, num_threads); } m_sizeA = this->m_mc * this->m_kc; @@ -451,7 +451,7 @@ class GeneralProduct<Lhs, Rhs, GemmProduct> (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>, _ActualLhsType, _ActualRhsType, Dest, BlockingType> GemmFunctor; - BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), true); + BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true); internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>(GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), this->rows(), this->cols(), Dest::Flags&RowMajorBit); } |