diff options
author | Rasmus Larsen <rmlarsen@google.com> | 2019-02-02 01:53:44 +0000 |
---|---|---|
committer | Rasmus Larsen <rmlarsen@google.com> | 2019-02-02 01:53:44 +0000 |
commit | e7b481ea7460e29e7cefd2d5c5bf527e163bb7f7 (patch) | |
tree | 05e0f6b2e4f893625e0c15156110477ae592438f /Eigen/src/Core/products | |
parent | b55b5c7280a0481f01fe5ec764d55c443a8b6496 (diff) | |
parent | 4c0fa6ce0f81ce67dd6723528ddf72f66ae92ba2 (diff) |
Merged in rmlarsen/eigen (pull request PR-578)
Speed up Eigen matrix*vector and vector*matrix multiplication.
Approved-by: Eugene Zhulenev <ezhulenev@google.com>
Diffstat (limited to 'Eigen/src/Core/products')
-rw-r--r-- | Eigen/src/Core/products/GeneralMatrixMatrix.h | 158 |
1 files changed, 129 insertions, 29 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index f49abcad5..4bcccd326 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -404,13 +404,13 @@ class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, M namespace internal { -template<typename Lhs, typename Rhs> -struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> - : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> > -{ +template <typename Lhs, typename Rhs, typename Dest, + bool MultipleRowsAtCompileTime = + (Lhs::RowsAtCompileTime > 1 || Dest::RowsAtCompileTime > 1), + bool MultipleColsAtCompileTime = + (Rhs::ColsAtCompileTime > 1 || Dest::ColsAtCompileTime > 1)> +struct gemm_selector { typedef typename Product<Lhs,Rhs>::Scalar Scalar; - typedef typename Lhs::Scalar LhsScalar; - typedef typename Rhs::Scalar RhsScalar; typedef internal::blas_traits<Lhs> LhsBlasTraits; typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; @@ -420,10 +420,130 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned; + static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha) + { + if (a_rhs.cols() != 1 && a_lhs.rows() != 1) { + gemm_selector<Lhs, Rhs, Dest, true, true>::run(dst, a_lhs, a_rhs, alpha); + } else if (a_rhs.cols() == 1) { + // matrix * vector. + internal::gemv_dense_selector<OnTheRight, + (int(ActualLhsTypeCleaned::Flags)&RowMajorBit) ? RowMajor : ColMajor, + bool(internal::blas_traits<ActualLhsTypeCleaned>::HasUsableDirectAccess) + >::run(a_lhs, a_rhs.col(0), dst, alpha); + } else { + // vector * matrix. + internal::gemv_dense_selector<OnTheLeft, + (int(ActualRhsTypeCleaned::Flags)&RowMajorBit) ? RowMajor : ColMajor, + bool(internal::blas_traits<ActualRhsTypeCleaned>::HasUsableDirectAccess) + >::run(a_lhs.row(0), a_rhs, dst, alpha); + } + } +}; + +template <typename Lhs, typename Rhs, typename Dest> +struct gemm_selector<Lhs, Rhs, Dest, true, false> { + typedef typename Product<Lhs,Rhs>::Scalar Scalar; + + typedef internal::blas_traits<Lhs> LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; + typedef typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned; + + static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha) + { + if (a_rhs.cols() != 1 && a_lhs.rows() != 1) { + gemm_selector<Lhs, Rhs, Dest, true, true>::run(dst, a_lhs, a_rhs, alpha); + } else { + // matrix * vector. + internal::gemv_dense_selector<OnTheRight, + (int(ActualLhsTypeCleaned::Flags)&RowMajorBit) ? RowMajor : ColMajor, + bool(internal::blas_traits<ActualLhsTypeCleaned>::HasUsableDirectAccess) + >::run(a_lhs, a_rhs.col(0), dst, alpha); + } + } +}; + +template <typename Lhs, typename Rhs, typename Dest> +struct gemm_selector<Lhs, Rhs, Dest, false, true> { + typedef typename Product<Lhs,Rhs>::Scalar Scalar; + + typedef internal::blas_traits<Rhs> RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; + typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned; + + static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha) + { + if (a_rhs.cols() != 1 && a_lhs.rows() != 1) { + gemm_selector<Lhs, Rhs, Dest, true, true>::run(dst, a_lhs, a_rhs, alpha); + } else { + // vector * matrix. + internal::gemv_dense_selector<OnTheLeft, + (int(ActualRhsTypeCleaned::Flags)&RowMajorBit) ? RowMajor : ColMajor, + bool(internal::blas_traits<ActualRhsTypeCleaned>::HasUsableDirectAccess) + >::run(a_lhs.row(0), a_rhs, dst, alpha); + } + } +}; + +template <typename Lhs, typename Rhs, typename Dest> +struct gemm_selector<Lhs, Rhs, Dest, true, true> { + typedef typename Product<Lhs, Rhs>::Scalar Scalar; + typedef typename Lhs::Scalar LhsScalar; + typedef typename Rhs::Scalar RhsScalar; + + typedef internal::blas_traits<Lhs> LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; + typedef + typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned; + + typedef internal::blas_traits<Rhs> RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; + typedef + typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned; + enum { - MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(Lhs::MaxColsAtCompileTime,Rhs::MaxRowsAtCompileTime) + MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED( + Lhs::MaxColsAtCompileTime, Rhs::MaxRowsAtCompileTime) }; + static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, + const Scalar& alpha) { + Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs) * + RhsBlasTraits::extractScalarFactor(a_rhs); + typename internal::add_const_on_value_type<ActualLhsType>::type lhs = + LhsBlasTraits::extract(a_lhs); + typename internal::add_const_on_value_type<ActualRhsType>::type rhs = + RhsBlasTraits::extract(a_rhs); + typedef internal::gemm_blocking_space< + (Dest::Flags & RowMajorBit) ? RowMajor : ColMajor, LhsScalar, RhsScalar, + Dest::MaxRowsAtCompileTime, Dest::MaxColsAtCompileTime, + MaxDepthAtCompileTime> + BlockingType; + + typedef internal::gemm_functor< + Scalar, Index, + internal::general_matrix_matrix_product< + Index, LhsScalar, + (ActualLhsTypeCleaned::Flags & RowMajorBit) ? RowMajor : ColMajor, + bool(LhsBlasTraits::NeedToConjugate), RhsScalar, + (ActualRhsTypeCleaned::Flags & RowMajorBit) ? RowMajor : ColMajor, + bool(RhsBlasTraits::NeedToConjugate), + (Dest::Flags & RowMajorBit) ? RowMajor : ColMajor>, + ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> + GemmFunctor; + + BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true); + internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime > 32 || + Dest::MaxRowsAtCompileTime == Dynamic)>( + GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), a_lhs.rows(), + a_rhs.cols(), a_lhs.cols(), Dest::Flags & RowMajorBit); + } +}; + +template<typename Lhs, typename Rhs> +struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> + : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> > +{ + typedef typename Product<Lhs,Rhs>::Scalar Scalar; typedef generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode> lazyproduct; template<typename Dst> @@ -450,7 +570,7 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> if((rhs.rows()+dst.rows()+dst.cols())<EIGEN_GEMM_TO_COEFFBASED_THRESHOLD && rhs.rows()>0) lazyproduct::eval_dynamic(dst, lhs, rhs, internal::add_assign_op<typename Dst::Scalar,Scalar>()); else - scaleAndAddTo(dst,lhs, rhs, Scalar(1)); + scaleAndAddTo(dst, lhs, rhs, Scalar(1)); } template<typename Dst> @@ -469,27 +589,7 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> if(a_lhs.cols()==0 || a_lhs.rows()==0 || a_rhs.cols()==0) return; - typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs); - typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs); - - Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs) - * RhsBlasTraits::extractScalarFactor(a_rhs); - - typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar, - Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType; - - typedef internal::gemm_functor< - Scalar, Index, - internal::general_matrix_matrix_product< - Index, - LhsScalar, (ActualLhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate), - RhsScalar, (ActualRhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate), - (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>, - ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> GemmFunctor; - - BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true); - internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)> - (GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), a_lhs.rows(), a_rhs.cols(), a_lhs.cols(), Dest::Flags&RowMajorBit); + gemm_selector<Lhs, Rhs, Dest>::run(dst, a_lhs, a_rhs, alpha); } }; |