diff options
author | 2014-02-21 16:43:03 +0100 | |
---|---|---|
committer | 2014-02-21 16:43:03 +0100 | |
commit | 6c7ab508117d84671054808c921980a4908efb20 (patch) | |
tree | 02a70acb7467a8d97a83c5df02f6679083e12dbe | |
parent | 728c3d2cb955a255cae5515197ae65dc83209509 (diff) |
Get rid of GeneralProduct<> for GemmProduct
-rw-r--r-- | Eigen/Core | 7 |
-rw-r--r-- | Eigen/src/Core/ProductEvaluators.h | 14 |
-rw-r--r-- | Eigen/src/Core/products/GeneralMatrixMatrix.h | 57 |
3 files changed, 60 insertions, 18 deletions
```diff
diff --git a/Eigen/Core b/Eigen/Core
index bcf354a48..245604465 100644
--- a/Eigen/Core
+++ b/Eigen/Core
@@ -371,6 +371,9 @@ using std::ptrdiff_t;
 #include "src/Core/products/GeneralBlockPanelKernel.h"
 #include "src/Core/products/Parallelizer.h"
 #include "src/Core/products/CoeffBasedProduct.h"
+#ifdef EIGEN_ENABLE_EVALUATORS
+#include "src/Core/ProductEvaluators.h"
+#endif
 #include "src/Core/products/GeneralMatrixVector.h"
 #include "src/Core/products/GeneralMatrixMatrix.h"
 #include "src/Core/SolveTriangular.h"
@@ -386,10 +389,6 @@ using std::ptrdiff_t;
 #include "src/Core/BandMatrix.h"
 #include "src/Core/CoreIterators.h"
 
-#ifdef EIGEN_ENABLE_EVALUATORS
-#include "src/Core/ProductEvaluators.h"
-#endif
-
 #include "src/Core/BooleanRedux.h"
 #include "src/Core/Select.h"
 #include "src/Core/VectorwiseOp.h"
diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h
index 93ae5f5f5..99c48ae52 100644
--- a/Eigen/src/Core/ProductEvaluators.h
+++ b/Eigen/src/Core/ProductEvaluators.h
@@ -312,20 +312,6 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemvProduct>
 };
 
 template<typename Lhs, typename Rhs>
-struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
-  : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> >
-{
-  typedef typename Product<Lhs,Rhs>::Scalar Scalar;
-
-  template<typename Dest>
-  static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
-  {
-    // TODO bypass GeneralProduct class
-    GeneralProduct<Lhs, Rhs, GemmProduct>(lhs,rhs).scaleAndAddTo(dst, alpha);
-  }
-};
-
-template<typename Lhs, typename Rhs>
 struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode>
 {
   typedef typename Product<Lhs,Rhs>::Scalar Scalar;
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h
index 3f5ffcf51..1c8940e1c 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h
@@ -374,6 +374,7 @@ class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, M
 
 } // end namespace internal
 
+#ifndef EIGEN_TEST_EVALUATORS
 template<typename Lhs, typename Rhs>
 class GeneralProduct<Lhs, Rhs, GemmProduct>
   : public ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs>
@@ -421,6 +422,62 @@ class GeneralProduct<Lhs, Rhs, GemmProduct>
       internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>(GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), this->rows(), this->cols(), Dest::Flags&RowMajorBit);
     }
 };
+#else // EIGEN_TEST_EVALUATORS
+namespace internal {
+
+template<typename Lhs, typename Rhs>
+struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
+  : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> >
+{
+  typedef typename Product<Lhs,Rhs>::Scalar Scalar;
+  typedef typename Product<Lhs,Rhs>::Index Index;
+  typedef typename Lhs::Scalar LhsScalar;
+  typedef typename Rhs::Scalar RhsScalar;
+
+  typedef internal::blas_traits<Lhs> LhsBlasTraits;
+  typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
+  typedef typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned;
+
+  typedef internal::blas_traits<Rhs> RhsBlasTraits;
+  typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
+  typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
+
+  enum {
+    MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(Lhs::MaxColsAtCompileTime,Rhs::MaxRowsAtCompileTime)
+  };
+
+  template<typename Dest>
+  static void scaleAndAddTo(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha)
+  {
+    eigen_assert(dst.rows()==a_lhs.rows() && dst.cols()==a_rhs.cols());
+
+    typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
+    typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
+
+    Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs)
+                               * RhsBlasTraits::extractScalarFactor(a_rhs);
+
+    typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar,
+            Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType;
+
+    typedef internal::gemm_functor<
+      Scalar, Index,
+      internal::general_matrix_matrix_product<
+        Index,
+        LhsScalar, (ActualLhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
+        RhsScalar, (ActualRhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
+        (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
+      ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> GemmFunctor;
+
+    BlockingType blocking(dst.rows(), dst.cols(), lhs.cols());
+
+    internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>
+        (GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), a_lhs.rows(), a_rhs.cols(), Dest::Flags&RowMajorBit);
+  }
+};
+
+} // end namespace internal
+#endif // EIGEN_TEST_EVALUATORS
 
 } // end namespace Eigen
 
```