about | summary | refs | log | tree | commit | diff | homepage
path: root/Eigen/src/Core/products
diff options
context:
space:
mode:
author Gael Guennebaud <g.gael@free.fr> 2019-02-07 16:07:08 +0100
committer Gael Guennebaud <g.gael@free.fr> 2019-02-07 16:07:08 +0100
commit fa2fcb4895a4ae12cb28003e646c736d013e68e8 (patch)
tree 57c973463fe0b798b2d235ae04455590a4479724 /Eigen/src/Core/products
parent b3c4344a6852e55c849976dd46ec4e861399bf16 (diff)
Diffstat (limited to 'Eigen/src/Core/products')
-rw-r--r-- Eigen/src/Core/products/GeneralMatrixMatrix.h 158
1 files changed, 29 insertions, 129 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h
index 4bcccd326..f49abcad5 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h
@@ -404,146 +404,26 @@ class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, M
namespace internal {
-template <typename Lhs, typename Rhs, typename Dest,
- bool MultipleRowsAtCompileTime =
- (Lhs::RowsAtCompileTime > 1 || Dest::RowsAtCompileTime > 1),
- bool MultipleColsAtCompileTime =
- (Rhs::ColsAtCompileTime > 1 || Dest::ColsAtCompileTime > 1)>
-struct gemm_selector {
- typedef typename Product<Lhs,Rhs>::Scalar Scalar;
-
- typedef internal::blas_traits<Lhs> LhsBlasTraits;
- typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
- typedef typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned;
-
- typedef internal::blas_traits<Rhs> RhsBlasTraits;
- typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
- typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
-
- static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha)
- {
- if (a_rhs.cols() != 1 && a_lhs.rows() != 1) {
- gemm_selector<Lhs, Rhs, Dest, true, true>::run(dst, a_lhs, a_rhs, alpha);
- } else if (a_rhs.cols() == 1) {
- // matrix * vector.
- internal::gemv_dense_selector<OnTheRight,
- (int(ActualLhsTypeCleaned::Flags)&RowMajorBit) ? RowMajor : ColMajor,
- bool(internal::blas_traits<ActualLhsTypeCleaned>::HasUsableDirectAccess)
- >::run(a_lhs, a_rhs.col(0), dst, alpha);
- } else {
- // vector * matrix.
- internal::gemv_dense_selector<OnTheLeft,
- (int(ActualRhsTypeCleaned::Flags)&RowMajorBit) ? RowMajor : ColMajor,
- bool(internal::blas_traits<ActualRhsTypeCleaned>::HasUsableDirectAccess)
- >::run(a_lhs.row(0), a_rhs, dst, alpha);
- }
- }
-};
-
-template <typename Lhs, typename Rhs, typename Dest>
-struct gemm_selector<Lhs, Rhs, Dest, true, false> {
- typedef typename Product<Lhs,Rhs>::Scalar Scalar;
-
- typedef internal::blas_traits<Lhs> LhsBlasTraits;
- typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
- typedef typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned;
-
- static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha)
- {
- if (a_rhs.cols() != 1 && a_lhs.rows() != 1) {
- gemm_selector<Lhs, Rhs, Dest, true, true>::run(dst, a_lhs, a_rhs, alpha);
- } else {
- // matrix * vector.
- internal::gemv_dense_selector<OnTheRight,
- (int(ActualLhsTypeCleaned::Flags)&RowMajorBit) ? RowMajor : ColMajor,
- bool(internal::blas_traits<ActualLhsTypeCleaned>::HasUsableDirectAccess)
- >::run(a_lhs, a_rhs.col(0), dst, alpha);
- }
- }
-};
-
-template <typename Lhs, typename Rhs, typename Dest>
-struct gemm_selector<Lhs, Rhs, Dest, false, true> {
+template<typename Lhs, typename Rhs>
+struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
+ : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> >
+{
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
-
- typedef internal::blas_traits<Rhs> RhsBlasTraits;
- typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
- typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
-
- static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha)
- {
- if (a_rhs.cols() != 1 && a_lhs.rows() != 1) {
- gemm_selector<Lhs, Rhs, Dest, true, true>::run(dst, a_lhs, a_rhs, alpha);
- } else {
- // vector * matrix.
- internal::gemv_dense_selector<OnTheLeft,
- (int(ActualRhsTypeCleaned::Flags)&RowMajorBit) ? RowMajor : ColMajor,
- bool(internal::blas_traits<ActualRhsTypeCleaned>::HasUsableDirectAccess)
- >::run(a_lhs.row(0), a_rhs, dst, alpha);
- }
- }
-};
-
-template <typename Lhs, typename Rhs, typename Dest>
-struct gemm_selector<Lhs, Rhs, Dest, true, true> {
- typedef typename Product<Lhs, Rhs>::Scalar Scalar;
typedef typename Lhs::Scalar LhsScalar;
typedef typename Rhs::Scalar RhsScalar;
typedef internal::blas_traits<Lhs> LhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
- typedef
- typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned;
+ typedef typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned;
typedef internal::blas_traits<Rhs> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
- typedef
- typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
+ typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
enum {
- MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(
- Lhs::MaxColsAtCompileTime, Rhs::MaxRowsAtCompileTime)
+ MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(Lhs::MaxColsAtCompileTime,Rhs::MaxRowsAtCompileTime)
};
- static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs,
- const Scalar& alpha) {
- Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs) *
- RhsBlasTraits::extractScalarFactor(a_rhs);
- typename internal::add_const_on_value_type<ActualLhsType>::type lhs =
- LhsBlasTraits::extract(a_lhs);
- typename internal::add_const_on_value_type<ActualRhsType>::type rhs =
- RhsBlasTraits::extract(a_rhs);
- typedef internal::gemm_blocking_space<
- (Dest::Flags & RowMajorBit) ? RowMajor : ColMajor, LhsScalar, RhsScalar,
- Dest::MaxRowsAtCompileTime, Dest::MaxColsAtCompileTime,
- MaxDepthAtCompileTime>
- BlockingType;
-
- typedef internal::gemm_functor<
- Scalar, Index,
- internal::general_matrix_matrix_product<
- Index, LhsScalar,
- (ActualLhsTypeCleaned::Flags & RowMajorBit) ? RowMajor : ColMajor,
- bool(LhsBlasTraits::NeedToConjugate), RhsScalar,
- (ActualRhsTypeCleaned::Flags & RowMajorBit) ? RowMajor : ColMajor,
- bool(RhsBlasTraits::NeedToConjugate),
- (Dest::Flags & RowMajorBit) ? RowMajor : ColMajor>,
- ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType>
- GemmFunctor;
-
- BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true);
- internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime > 32 ||
- Dest::MaxRowsAtCompileTime == Dynamic)>(
- GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), a_lhs.rows(),
- a_rhs.cols(), a_lhs.cols(), Dest::Flags & RowMajorBit);
- }
-};
-
-template<typename Lhs, typename Rhs>
-struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
- : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> >
-{
- typedef typename Product<Lhs,Rhs>::Scalar Scalar;
typedef generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode> lazyproduct;
template<typename Dst>
@@ -570,7 +450,7 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
if((rhs.rows()+dst.rows()+dst.cols())<EIGEN_GEMM_TO_COEFFBASED_THRESHOLD && rhs.rows()>0)
lazyproduct::eval_dynamic(dst, lhs, rhs, internal::add_assign_op<typename Dst::Scalar,Scalar>());
else
- scaleAndAddTo(dst, lhs, rhs, Scalar(1));
+ scaleAndAddTo(dst,lhs, rhs, Scalar(1));
}
template<typename Dst>
@@ -589,7 +469,27 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
if(a_lhs.cols()==0 || a_lhs.rows()==0 || a_rhs.cols()==0)
return;
- gemm_selector<Lhs, Rhs, Dest>::run(dst, a_lhs, a_rhs, alpha);
+ typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
+ typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
+
+ Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs)
+ * RhsBlasTraits::extractScalarFactor(a_rhs);
+
+ typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar,
+ Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType;
+
+ typedef internal::gemm_functor<
+ Scalar, Index,
+ internal::general_matrix_matrix_product<
+ Index,
+ LhsScalar, (ActualLhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
+ RhsScalar, (ActualRhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
+ (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
+ ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> GemmFunctor;
+
+ BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true);
+ internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>
+ (GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), a_lhs.rows(), a_rhs.cols(), a_lhs.cols(), Dest::Flags&RowMajorBit);
}
};