diff options
Diffstat (limited to 'Eigen/src/Core/products/GeneralMatrixMatrix.h')
-rw-r--r-- | Eigen/src/Core/products/GeneralMatrixMatrix.h | 31 |
1 files changed, 17 insertions, 14 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index 90c9c4647..508c05c97 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -20,8 +20,9 @@ template<typename _LhsScalar, typename _RhsScalar> class level3_blocking; template< typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, - typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs> -struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor> + typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, + int ResInnerStride> +struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor,ResInnerStride> { typedef gebp_traits<RhsScalar,LhsScalar> Traits; @@ -30,7 +31,7 @@ struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLh Index rows, Index cols, Index depth, const LhsScalar* lhs, Index lhsStride, const RhsScalar* rhs, Index rhsStride, - ResScalar* res, Index resStride, + ResScalar* res, Index resIncr, Index resStride, ResScalar alpha, level3_blocking<RhsScalar,LhsScalar>& blocking, GemmParallelInfo<Index>* info = 0) @@ -39,8 +40,8 @@ struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLh general_matrix_matrix_product<Index, RhsScalar, RhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateRhs, LhsScalar, LhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateLhs, - ColMajor> - ::run(cols,rows,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha,blocking,info); + ColMajor,ResInnerStride> + ::run(cols,rows,depth,rhs,rhsStride,lhs,lhsStride,res,resIncr,resStride,alpha,blocking,info); } }; @@ -49,8 +50,9 @@ struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLh template< typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, - typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs> -struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor> + typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, + int ResInnerStride> +struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride> { typedef gebp_traits<LhsScalar,RhsScalar> Traits; @@ -59,17 +61,17 @@ typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScala static void run(Index rows, Index cols, Index depth, const LhsScalar* _lhs, Index lhsStride, const RhsScalar* _rhs, Index rhsStride, - ResScalar* _res, Index resStride, + ResScalar* _res, Index resIncr, Index resStride, ResScalar alpha, level3_blocking<LhsScalar,RhsScalar>& blocking, GemmParallelInfo<Index>* info = 0) { typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper; typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper; - typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor> ResMapper; - LhsMapper lhs(_lhs,lhsStride); - RhsMapper rhs(_rhs,rhsStride); - ResMapper res(_res, resStride); + typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor,Unaligned,ResInnerStride> ResMapper; + LhsMapper lhs(_lhs, lhsStride); + RhsMapper rhs(_rhs, rhsStride); + ResMapper res(_res, resStride, resIncr); Index kc = blocking.kc(); // cache block size along the K direction Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction @@ -228,7 +230,7 @@ struct gemm_functor Gemm::run(rows, cols, m_lhs.cols(), &m_lhs.coeffRef(row,0), m_lhs.outerStride(), &m_rhs.coeffRef(0,col), m_rhs.outerStride(), - (Scalar*)&(m_dest.coeffRef(row,col)), m_dest.outerStride(), + (Scalar*)&(m_dest.coeffRef(row,col)), m_dest.innerStride(), m_dest.outerStride(), m_actualAlpha, m_blocking, info); } @@ -498,7 +500,8 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> Index, LhsScalar, (ActualLhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate), RhsScalar, (ActualRhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate), - (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>, + (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor, + Dest::InnerStrideAtCompileTime>, ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> GemmFunctor; BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true); |