diff options
Diffstat (limited to 'Eigen/src/Core/products/GeneralMatrixMatrix.h')
-rw-r--r-- | Eigen/src/Core/products/GeneralMatrixMatrix.h | 141 |
1 files changed, 75 insertions, 66 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index 6ad07eccb..b7e1867f0 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -216,8 +216,8 @@ struct gemm_functor cols = m_rhs.cols(); Gemm::run(rows, cols, m_lhs.cols(), - /*(const Scalar*)*/&m_lhs.coeffRef(row,0), m_lhs.outerStride(), - /*(const Scalar*)*/&m_rhs.coeffRef(0,col), m_rhs.outerStride(), + &m_lhs.coeffRef(row,0), m_lhs.outerStride(), + &m_rhs.coeffRef(0,col), m_rhs.outerStride(), (Scalar*)&(m_dest.coeffRef(row,col)), m_dest.outerStride(), m_actualAlpha, m_blocking, info); } @@ -367,84 +367,93 @@ class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, M } // end namespace internal +namespace internal { + template<typename Lhs, typename Rhs> -class GeneralProduct<Lhs, Rhs, GemmProduct> - : public ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs> +struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> + : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> > { - enum { - MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(Lhs::MaxColsAtCompileTime,Rhs::MaxRowsAtCompileTime) - }; - public: - EIGEN_PRODUCT_PUBLIC_INTERFACE(GeneralProduct) - - typedef typename Lhs::Scalar LhsScalar; - typedef typename Rhs::Scalar RhsScalar; - typedef Scalar ResScalar; - - GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) - { - typedef internal::scalar_product_op<LhsScalar,RhsScalar> BinOp; - EIGEN_CHECK_BINARY_COMPATIBILIY(BinOp,LhsScalar,RhsScalar); - } - - template<typename Dest> - inline void evalTo(Dest& dst) const + typedef typename Product<Lhs,Rhs>::Scalar Scalar; + typedef typename Product<Lhs,Rhs>::Index Index; + typedef typename Lhs::Scalar LhsScalar; + typedef typename Rhs::Scalar RhsScalar; + + typedef internal::blas_traits<Lhs> LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; + typedef typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned; + + typedef internal::blas_traits<Rhs> RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; + typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned; + + enum { + MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(Lhs::MaxColsAtCompileTime,Rhs::MaxRowsAtCompileTime) + }; + + typedef generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode> lazyproduct; + + template<typename Dst> + static void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) + { + if((rhs.rows()+dst.rows()+dst.cols())<20 && rhs.rows()>0) + lazyproduct::evalTo(dst, lhs, rhs); + else { - if((m_rhs.rows()+dst.rows()+dst.cols())<20 && m_rhs.rows()>0) - dst.noalias() = m_lhs .lazyProduct( m_rhs ); - else - { - dst.setZero(); - scaleAndAddTo(dst,Scalar(1)); - } + dst.setZero(); + scaleAndAddTo(dst, lhs, rhs, Scalar(1)); } + } - template<typename Dest> - inline void addTo(Dest& dst) const - { - if((m_rhs.rows()+dst.rows()+dst.cols())<20 && m_rhs.rows()>0) - dst.noalias() += m_lhs .lazyProduct( m_rhs ); - else - scaleAndAddTo(dst,Scalar(1)); - } + template<typename Dst> + static void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) + { + if((rhs.rows()+dst.rows()+dst.cols())<20 && rhs.rows()>0) + lazyproduct::addTo(dst, lhs, rhs); + else + scaleAndAddTo(dst,lhs, rhs, Scalar(1)); + } - template<typename Dest> - inline void subTo(Dest& dst) const - { - if((m_rhs.rows()+dst.rows()+dst.cols())<20 && m_rhs.rows()>0) - dst.noalias() -= m_lhs .lazyProduct( m_rhs ); - else - scaleAndAddTo(dst,Scalar(-1)); - } - - template<typename Dest> void scaleAndAddTo(Dest& dst, const Scalar& alpha) const - { - eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); + template<typename Dst> + static void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) + { + if((rhs.rows()+dst.rows()+dst.cols())<20 && rhs.rows()>0) + lazyproduct::subTo(dst, lhs, rhs); + else + scaleAndAddTo(dst, lhs, rhs, Scalar(-1)); + } + + template<typename Dest> + static void scaleAndAddTo(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha) + { + eigen_assert(dst.rows()==a_lhs.rows() && dst.cols()==a_rhs.cols()); - typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(m_lhs); - typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(m_rhs); + typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs); + typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs); - Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs) - * RhsBlasTraits::extractScalarFactor(m_rhs); + Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs) + * RhsBlasTraits::extractScalarFactor(a_rhs); - typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar, - Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType; + typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar, + Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType; - typedef internal::gemm_functor< - Scalar, Index, - internal::general_matrix_matrix_product< - Index, - LhsScalar, (_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate), - RhsScalar, (_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate), - (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>, - _ActualLhsType, _ActualRhsType, Dest, BlockingType> GemmFunctor; + typedef internal::gemm_functor< + Scalar, Index, + internal::general_matrix_matrix_product< + Index, + LhsScalar, (ActualLhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate), + RhsScalar, (ActualRhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate), + (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>, + ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> GemmFunctor; - BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), true); + BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), true); - internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>(GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), this->rows(), this->cols(), Dest::Flags&RowMajorBit); - } + internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)> + (GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), a_lhs.rows(), a_rhs.cols(), Dest::Flags&RowMajorBit); + } }; +} // end namespace internal + } // end namespace Eigen #endif // EIGEN_GENERAL_MATRIX_MATRIX_H |