aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2014-02-21 16:43:03 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2014-02-21 16:43:03 +0100
commit6c7ab508117d84671054808c921980a4908efb20 (patch)
tree02a70acb7467a8d97a83c5df02f6679083e12dbe
parent728c3d2cb955a255cae5515197ae65dc83209509 (diff)
Get rid of GeneralProduct<> for GemmProduct
-rw-r--r--Eigen/Core7
-rw-r--r--Eigen/src/Core/ProductEvaluators.h14
-rw-r--r--Eigen/src/Core/products/GeneralMatrixMatrix.h57
3 files changed, 60 insertions, 18 deletions
diff --git a/Eigen/Core b/Eigen/Core
index bcf354a48..245604465 100644
--- a/Eigen/Core
+++ b/Eigen/Core
@@ -371,6 +371,9 @@ using std::ptrdiff_t;
#include "src/Core/products/GeneralBlockPanelKernel.h"
#include "src/Core/products/Parallelizer.h"
#include "src/Core/products/CoeffBasedProduct.h"
+#ifdef EIGEN_ENABLE_EVALUATORS
+#include "src/Core/ProductEvaluators.h"
+#endif
#include "src/Core/products/GeneralMatrixVector.h"
#include "src/Core/products/GeneralMatrixMatrix.h"
#include "src/Core/SolveTriangular.h"
@@ -386,10 +389,6 @@ using std::ptrdiff_t;
#include "src/Core/BandMatrix.h"
#include "src/Core/CoreIterators.h"
-#ifdef EIGEN_ENABLE_EVALUATORS
-#include "src/Core/ProductEvaluators.h"
-#endif
-
#include "src/Core/BooleanRedux.h"
#include "src/Core/Select.h"
#include "src/Core/VectorwiseOp.h"
diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h
index 93ae5f5f5..99c48ae52 100644
--- a/Eigen/src/Core/ProductEvaluators.h
+++ b/Eigen/src/Core/ProductEvaluators.h
@@ -312,20 +312,6 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemvProduct>
};
template<typename Lhs, typename Rhs>
-struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
- : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> >
-{
- typedef typename Product<Lhs,Rhs>::Scalar Scalar;
-
- template<typename Dest>
- static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
- {
- // TODO bypass GeneralProduct class
- GeneralProduct<Lhs, Rhs, GemmProduct>(lhs,rhs).scaleAndAddTo(dst, alpha);
- }
-};
-
-template<typename Lhs, typename Rhs>
struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode>
{
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h
index 3f5ffcf51..1c8940e1c 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h
@@ -374,6 +374,7 @@ class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, M
} // end namespace internal
+#ifndef EIGEN_TEST_EVALUATORS
template<typename Lhs, typename Rhs>
class GeneralProduct<Lhs, Rhs, GemmProduct>
: public ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs>
@@ -421,6 +422,62 @@ class GeneralProduct<Lhs, Rhs, GemmProduct>
internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>(GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), this->rows(), this->cols(), Dest::Flags&RowMajorBit);
}
};
+#else // EIGEN_TEST_EVALUATORS
+namespace internal {
+
+template<typename Lhs, typename Rhs>
+struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
+ : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> >
+{
+ typedef typename Product<Lhs,Rhs>::Scalar Scalar;
+ typedef typename Product<Lhs,Rhs>::Index Index;
+ typedef typename Lhs::Scalar LhsScalar;
+ typedef typename Rhs::Scalar RhsScalar;
+
+ typedef internal::blas_traits<Lhs> LhsBlasTraits;
+ typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
+ typedef typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned;
+
+ typedef internal::blas_traits<Rhs> RhsBlasTraits;
+ typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
+ typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
+
+ enum {
+ MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(Lhs::MaxColsAtCompileTime,Rhs::MaxRowsAtCompileTime)
+ };
+
+ template<typename Dest>
+ static void scaleAndAddTo(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha)
+ {
+ eigen_assert(dst.rows()==a_lhs.rows() && dst.cols()==a_rhs.cols());
+
+ typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
+ typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
+
+ Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs)
+ * RhsBlasTraits::extractScalarFactor(a_rhs);
+
+ typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar,
+ Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType;
+
+ typedef internal::gemm_functor<
+ Scalar, Index,
+ internal::general_matrix_matrix_product<
+ Index,
+ LhsScalar, (ActualLhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
+ RhsScalar, (ActualRhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
+ (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
+ ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> GemmFunctor;
+
+ BlockingType blocking(dst.rows(), dst.cols(), lhs.cols());
+
+ internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>
+ (GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), a_lhs.rows(), a_rhs.cols(), Dest::Flags&RowMajorBit);
+ }
+};
+
+} // end namespace internal
+#endif // EIGEN_TEST_EVALUATORS
} // end namespace Eigen