aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/products/GeneralMatrixMatrix.h
diff options
context:
space:
mode:
Diffstat (limited to 'Eigen/src/Core/products/GeneralMatrixMatrix.h')
-rw-r--r--Eigen/src/Core/products/GeneralMatrixMatrix.h142
1 files changed, 75 insertions, 67 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h
index 49362adbe..fd9443cd2 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h
@@ -228,8 +228,8 @@ struct gemm_functor
cols = m_rhs.cols();
Gemm::run(rows, cols, m_lhs.cols(),
- /*(const Scalar*)*/&m_lhs.coeffRef(row,0), m_lhs.outerStride(),
- /*(const Scalar*)*/&m_rhs.coeffRef(0,col), m_rhs.outerStride(),
+ &m_lhs.coeffRef(row,0), m_lhs.outerStride(),
+ &m_rhs.coeffRef(0,col), m_rhs.outerStride(),
(Scalar*)&(m_dest.coeffRef(row,col)), m_dest.outerStride(),
m_actualAlpha, m_blocking, info);
}
@@ -379,84 +379,92 @@ class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, M
} // end namespace internal
+namespace internal {
+
template<typename Lhs, typename Rhs>
-class GeneralProduct<Lhs, Rhs, GemmProduct>
- : public ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs>
+struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
+ : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> >
{
- enum {
- MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(Lhs::MaxColsAtCompileTime,Rhs::MaxRowsAtCompileTime)
- };
- public:
- EIGEN_PRODUCT_PUBLIC_INTERFACE(GeneralProduct)
-
- typedef typename Lhs::Scalar LhsScalar;
- typedef typename Rhs::Scalar RhsScalar;
- typedef Scalar ResScalar;
-
- GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
- {
- typedef internal::scalar_product_op<LhsScalar,RhsScalar> BinOp;
- EIGEN_CHECK_BINARY_COMPATIBILIY(BinOp,LhsScalar,RhsScalar);
- }
-
- template<typename Dest>
- inline void evalTo(Dest& dst) const
- {
- if((m_rhs.rows()+dst.rows()+dst.cols())<20 && m_rhs.rows()>0)
- dst.noalias() = m_lhs .lazyProduct( m_rhs );
- else
- {
- dst.setZero();
- scaleAndAddTo(dst,Scalar(1));
- }
- }
-
- template<typename Dest>
- inline void addTo(Dest& dst) const
+ typedef typename Product<Lhs,Rhs>::Scalar Scalar;
+ typedef typename Product<Lhs,Rhs>::Index Index;
+ typedef typename Lhs::Scalar LhsScalar;
+ typedef typename Rhs::Scalar RhsScalar;
+
+ typedef internal::blas_traits<Lhs> LhsBlasTraits;
+ typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
+ typedef typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned;
+
+ typedef internal::blas_traits<Rhs> RhsBlasTraits;
+ typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
+ typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
+
+ enum {
+ MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(Lhs::MaxColsAtCompileTime,Rhs::MaxRowsAtCompileTime)
+ };
+
+ typedef generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode> lazyproduct;
+
+ template<typename Dst>
+ static void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
+ {
+ if((rhs.rows()+dst.rows()+dst.cols())<20 && rhs.rows()>0)
+ lazyproduct::evalTo(dst, lhs, rhs);
+ else
{
- if((m_rhs.rows()+dst.rows()+dst.cols())<20 && m_rhs.rows()>0)
- dst.noalias() += m_lhs .lazyProduct( m_rhs );
- else
- scaleAndAddTo(dst,Scalar(1));
+ dst.setZero();
+ scaleAndAddTo(dst, lhs, rhs, Scalar(1));
}
+ }
- template<typename Dest>
- inline void subTo(Dest& dst) const
- {
- if((m_rhs.rows()+dst.rows()+dst.cols())<20 && m_rhs.rows()>0)
- dst.noalias() -= m_lhs .lazyProduct( m_rhs );
- else
- scaleAndAddTo(dst,Scalar(-1));
- }
-
- template<typename Dest> void scaleAndAddTo(Dest& dst, const Scalar& alpha) const
- {
- eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
+ template<typename Dst>
+ static void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
+ {
+ if((rhs.rows()+dst.rows()+dst.cols())<20 && rhs.rows()>0)
+ lazyproduct::addTo(dst, lhs, rhs);
+ else
+ scaleAndAddTo(dst,lhs, rhs, Scalar(1));
+ }
- typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(m_lhs);
- typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(m_rhs);
+ template<typename Dst>
+ static void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
+ {
+ if((rhs.rows()+dst.rows()+dst.cols())<20 && rhs.rows()>0)
+ lazyproduct::subTo(dst, lhs, rhs);
+ else
+ scaleAndAddTo(dst, lhs, rhs, Scalar(-1));
+ }
+
+ template<typename Dest>
+ static void scaleAndAddTo(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha)
+ {
+ eigen_assert(dst.rows()==a_lhs.rows() && dst.cols()==a_rhs.cols());
- Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
- * RhsBlasTraits::extractScalarFactor(m_rhs);
+ typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
+ typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
- typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar,
- Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType;
+ Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs)
+ * RhsBlasTraits::extractScalarFactor(a_rhs);
- typedef internal::gemm_functor<
- Scalar, Index,
- internal::general_matrix_matrix_product<
- Index,
- LhsScalar, (_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
- RhsScalar, (_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
- (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
- _ActualLhsType, _ActualRhsType, Dest, BlockingType> GemmFunctor;
+ typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar,
+ Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType;
- BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true);
+ typedef internal::gemm_functor<
+ Scalar, Index,
+ internal::general_matrix_matrix_product<
+ Index,
+ LhsScalar, (ActualLhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
+ RhsScalar, (ActualRhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
+ (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
+ ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> GemmFunctor;
- internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>(GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), this->rows(), this->cols(), Dest::Flags&RowMajorBit);
- }
+ BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true);
+ internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>
+ (GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), a_lhs.rows(), a_rhs.cols(), Dest::Flags&RowMajorBit);
+ }
};
+} // end namespace internal
+
} // end namespace Eigen
#endif // EIGEN_GENERAL_MATRIX_MATRIX_H