aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/products
diff options
context:
space:
mode:
authorGravatar Rasmus Larsen <rmlarsen@google.com>2019-02-02 01:53:44 +0000
committerGravatar Rasmus Larsen <rmlarsen@google.com>2019-02-02 01:53:44 +0000
commite7b481ea7460e29e7cefd2d5c5bf527e163bb7f7 (patch)
tree05e0f6b2e4f893625e0c15156110477ae592438f /Eigen/src/Core/products
parentb55b5c7280a0481f01fe5ec764d55c443a8b6496 (diff)
parent4c0fa6ce0f81ce67dd6723528ddf72f66ae92ba2 (diff)
Merged in rmlarsen/eigen (pull request PR-578)
Speed up Eigen matrix*vector and vector*matrix multiplication. Approved-by: Eugene Zhulenev <ezhulenev@google.com>
Diffstat (limited to 'Eigen/src/Core/products')
-rw-r--r--Eigen/src/Core/products/GeneralMatrixMatrix.h158
1 files changed, 129 insertions, 29 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h
index f49abcad5..4bcccd326 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h
@@ -404,13 +404,13 @@ class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, M
namespace internal {
-template<typename Lhs, typename Rhs>
-struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
- : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> >
-{
+template <typename Lhs, typename Rhs, typename Dest,
+ bool MultipleRowsAtCompileTime =
+ (Lhs::RowsAtCompileTime > 1 || Dest::RowsAtCompileTime > 1),
+ bool MultipleColsAtCompileTime =
+ (Rhs::ColsAtCompileTime > 1 || Dest::ColsAtCompileTime > 1)>
+struct gemm_selector {
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
- typedef typename Lhs::Scalar LhsScalar;
- typedef typename Rhs::Scalar RhsScalar;
typedef internal::blas_traits<Lhs> LhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
@@ -420,10 +420,130 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
+ static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha)
+ {
+ if (a_rhs.cols() != 1 && a_lhs.rows() != 1) {
+ gemm_selector<Lhs, Rhs, Dest, true, true>::run(dst, a_lhs, a_rhs, alpha);
+ } else if (a_rhs.cols() == 1) {
+ // matrix * vector.
+ internal::gemv_dense_selector<OnTheRight,
+ (int(ActualLhsTypeCleaned::Flags)&RowMajorBit) ? RowMajor : ColMajor,
+ bool(internal::blas_traits<ActualLhsTypeCleaned>::HasUsableDirectAccess)
+ >::run(a_lhs, a_rhs.col(0), dst, alpha);
+ } else {
+ // vector * matrix.
+ internal::gemv_dense_selector<OnTheLeft,
+ (int(ActualRhsTypeCleaned::Flags)&RowMajorBit) ? RowMajor : ColMajor,
+ bool(internal::blas_traits<ActualRhsTypeCleaned>::HasUsableDirectAccess)
+ >::run(a_lhs.row(0), a_rhs, dst, alpha);
+ }
+ }
+};
+
+template <typename Lhs, typename Rhs, typename Dest>
+struct gemm_selector<Lhs, Rhs, Dest, true, false> {
+ typedef typename Product<Lhs,Rhs>::Scalar Scalar;
+
+ typedef internal::blas_traits<Lhs> LhsBlasTraits;
+ typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
+ typedef typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned;
+
+ static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha)
+ {
+ if (a_rhs.cols() != 1 && a_lhs.rows() != 1) {
+ gemm_selector<Lhs, Rhs, Dest, true, true>::run(dst, a_lhs, a_rhs, alpha);
+ } else {
+ // matrix * vector.
+ internal::gemv_dense_selector<OnTheRight,
+ (int(ActualLhsTypeCleaned::Flags)&RowMajorBit) ? RowMajor : ColMajor,
+ bool(internal::blas_traits<ActualLhsTypeCleaned>::HasUsableDirectAccess)
+ >::run(a_lhs, a_rhs.col(0), dst, alpha);
+ }
+ }
+};
+
+template <typename Lhs, typename Rhs, typename Dest>
+struct gemm_selector<Lhs, Rhs, Dest, false, true> {
+ typedef typename Product<Lhs,Rhs>::Scalar Scalar;
+
+ typedef internal::blas_traits<Rhs> RhsBlasTraits;
+ typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
+ typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
+
+ static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha)
+ {
+ if (a_rhs.cols() != 1 && a_lhs.rows() != 1) {
+ gemm_selector<Lhs, Rhs, Dest, true, true>::run(dst, a_lhs, a_rhs, alpha);
+ } else {
+ // vector * matrix.
+ internal::gemv_dense_selector<OnTheLeft,
+ (int(ActualRhsTypeCleaned::Flags)&RowMajorBit) ? RowMajor : ColMajor,
+ bool(internal::blas_traits<ActualRhsTypeCleaned>::HasUsableDirectAccess)
+ >::run(a_lhs.row(0), a_rhs, dst, alpha);
+ }
+ }
+};
+
+template <typename Lhs, typename Rhs, typename Dest>
+struct gemm_selector<Lhs, Rhs, Dest, true, true> {
+ typedef typename Product<Lhs, Rhs>::Scalar Scalar;
+ typedef typename Lhs::Scalar LhsScalar;
+ typedef typename Rhs::Scalar RhsScalar;
+
+ typedef internal::blas_traits<Lhs> LhsBlasTraits;
+ typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
+ typedef
+ typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned;
+
+ typedef internal::blas_traits<Rhs> RhsBlasTraits;
+ typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
+ typedef
+ typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
+
enum {
- MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(Lhs::MaxColsAtCompileTime,Rhs::MaxRowsAtCompileTime)
+ MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(
+ Lhs::MaxColsAtCompileTime, Rhs::MaxRowsAtCompileTime)
};
+ static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs,
+ const Scalar& alpha) {
+ Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs) *
+ RhsBlasTraits::extractScalarFactor(a_rhs);
+ typename internal::add_const_on_value_type<ActualLhsType>::type lhs =
+ LhsBlasTraits::extract(a_lhs);
+ typename internal::add_const_on_value_type<ActualRhsType>::type rhs =
+ RhsBlasTraits::extract(a_rhs);
+ typedef internal::gemm_blocking_space<
+ (Dest::Flags & RowMajorBit) ? RowMajor : ColMajor, LhsScalar, RhsScalar,
+ Dest::MaxRowsAtCompileTime, Dest::MaxColsAtCompileTime,
+ MaxDepthAtCompileTime>
+ BlockingType;
+
+ typedef internal::gemm_functor<
+ Scalar, Index,
+ internal::general_matrix_matrix_product<
+ Index, LhsScalar,
+ (ActualLhsTypeCleaned::Flags & RowMajorBit) ? RowMajor : ColMajor,
+ bool(LhsBlasTraits::NeedToConjugate), RhsScalar,
+ (ActualRhsTypeCleaned::Flags & RowMajorBit) ? RowMajor : ColMajor,
+ bool(RhsBlasTraits::NeedToConjugate),
+ (Dest::Flags & RowMajorBit) ? RowMajor : ColMajor>,
+ ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType>
+ GemmFunctor;
+
+ BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true);
+ internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime > 32 ||
+ Dest::MaxRowsAtCompileTime == Dynamic)>(
+ GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), a_lhs.rows(),
+ a_rhs.cols(), a_lhs.cols(), Dest::Flags & RowMajorBit);
+ }
+};
+
+template<typename Lhs, typename Rhs>
+struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
+ : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> >
+{
+ typedef typename Product<Lhs,Rhs>::Scalar Scalar;
typedef generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode> lazyproduct;
template<typename Dst>
@@ -450,7 +570,7 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
if((rhs.rows()+dst.rows()+dst.cols())<EIGEN_GEMM_TO_COEFFBASED_THRESHOLD && rhs.rows()>0)
lazyproduct::eval_dynamic(dst, lhs, rhs, internal::add_assign_op<typename Dst::Scalar,Scalar>());
else
- scaleAndAddTo(dst,lhs, rhs, Scalar(1));
+ scaleAndAddTo(dst, lhs, rhs, Scalar(1));
}
template<typename Dst>
@@ -469,27 +589,7 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
if(a_lhs.cols()==0 || a_lhs.rows()==0 || a_rhs.cols()==0)
return;
- typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
- typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
-
- Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs)
- * RhsBlasTraits::extractScalarFactor(a_rhs);
-
- typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar,
- Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType;
-
- typedef internal::gemm_functor<
- Scalar, Index,
- internal::general_matrix_matrix_product<
- Index,
- LhsScalar, (ActualLhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
- RhsScalar, (ActualRhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
- (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
- ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> GemmFunctor;
-
- BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true);
- internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>
- (GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), a_lhs.rows(), a_rhs.cols(), a_lhs.cols(), Dest::Flags&RowMajorBit);
+ gemm_selector<Lhs, Rhs, Dest>::run(dst, a_lhs, a_rhs, alpha);
}
};