From 6c7ab508117d84671054808c921980a4908efb20 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Fri, 21 Feb 2014 16:43:03 +0100 Subject: Get rid of GeneralProduct<> for GemmProduct --- Eigen/Core | 7 ++-- Eigen/src/Core/ProductEvaluators.h | 14 ------- Eigen/src/Core/products/GeneralMatrixMatrix.h | 57 +++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 18 deletions(-) diff --git a/Eigen/Core b/Eigen/Core index bcf354a48..245604465 100644 --- a/Eigen/Core +++ b/Eigen/Core @@ -371,6 +371,9 @@ using std::ptrdiff_t; #include "src/Core/products/GeneralBlockPanelKernel.h" #include "src/Core/products/Parallelizer.h" #include "src/Core/products/CoeffBasedProduct.h" +#ifdef EIGEN_ENABLE_EVALUATORS +#include "src/Core/ProductEvaluators.h" +#endif #include "src/Core/products/GeneralMatrixVector.h" #include "src/Core/products/GeneralMatrixMatrix.h" #include "src/Core/SolveTriangular.h" @@ -386,10 +389,6 @@ using std::ptrdiff_t; #include "src/Core/BandMatrix.h" #include "src/Core/CoreIterators.h" -#ifdef EIGEN_ENABLE_EVALUATORS -#include "src/Core/ProductEvaluators.h" -#endif - #include "src/Core/BooleanRedux.h" #include "src/Core/Select.h" #include "src/Core/VectorwiseOp.h" diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h index 93ae5f5f5..99c48ae52 100644 --- a/Eigen/src/Core/ProductEvaluators.h +++ b/Eigen/src/Core/ProductEvaluators.h @@ -311,20 +311,6 @@ struct generic_product_impl } }; -template -struct generic_product_impl - : generic_product_impl_base > -{ - typedef typename Product::Scalar Scalar; - - template - static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) - { - // TODO bypass GeneralProduct class - GeneralProduct(lhs,rhs).scaleAndAddTo(dst, alpha); - } -}; - template struct generic_product_impl { diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index 3f5ffcf51..1c8940e1c 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -374,6 +374,7 @@ class gemm_blocking_space class GeneralProduct : public ProductBase, Lhs, Rhs> @@ -421,6 +422,62 @@ class GeneralProduct internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>(GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), this->rows(), this->cols(), Dest::Flags&RowMajorBit); } }; +#else // EIGEN_TEST_EVALUATORS +namespace internal { + +template +struct generic_product_impl + : generic_product_impl_base > +{ + typedef typename Product::Scalar Scalar; + typedef typename Product::Index Index; + typedef typename Lhs::Scalar LhsScalar; + typedef typename Rhs::Scalar RhsScalar; + + typedef internal::blas_traits LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; + typedef typename internal::remove_all::type ActualLhsTypeCleaned; + + typedef internal::blas_traits RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; + typedef typename internal::remove_all::type ActualRhsTypeCleaned; + + enum { + MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(Lhs::MaxColsAtCompileTime,Rhs::MaxRowsAtCompileTime) + }; + + template + static void scaleAndAddTo(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha) + { + eigen_assert(dst.rows()==a_lhs.rows() && dst.cols()==a_rhs.cols()); + + typename internal::add_const_on_value_type::type lhs = LhsBlasTraits::extract(a_lhs); + typename internal::add_const_on_value_type::type rhs = RhsBlasTraits::extract(a_rhs); + + Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs) + * RhsBlasTraits::extractScalarFactor(a_rhs); + + typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar, + Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType; + + typedef internal::gemm_functor< + Scalar, Index, + internal::general_matrix_matrix_product< + Index, + LhsScalar, (ActualLhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate), + RhsScalar, (ActualRhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate), + (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>, + ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> GemmFunctor; + + BlockingType blocking(dst.rows(), dst.cols(), lhs.cols()); + + internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)> + (GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), a_lhs.rows(), a_rhs.cols(), Dest::Flags&RowMajorBit); + } +}; + +} // end namespace internal +#endif // EIGEN_TEST_EVALUATORS } // end namespace Eigen -- cgit v1.2.3