From 2dde63499c4ef836a0d9dfd443494d863ad62b16 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Fri, 31 Oct 2014 16:33:51 -0700 Subject: Generalized the matrix vector product code. --- Eigen/src/Core/GeneralProduct.h | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) (limited to 'Eigen/src/Core/GeneralProduct.h') diff --git a/Eigen/src/Core/GeneralProduct.h b/Eigen/src/Core/GeneralProduct.h index 7179eb124..9d3d5562c 100644 --- a/Eigen/src/Core/GeneralProduct.h +++ b/Eigen/src/Core/GeneralProduct.h @@ -11,7 +11,7 @@ #ifndef EIGEN_GENERAL_PRODUCT_H #define EIGEN_GENERAL_PRODUCT_H -namespace Eigen { +namespace Eigen { /** \class GeneralProduct * \ingroup Core_Module @@ -257,7 +257,7 @@ class GeneralProduct : public ProductBase, Lhs, Rhs> { template struct IsRowMajor : internal::conditional<(int(T::Flags)&RowMajorBit), internal::true_type, internal::false_type>::type {}; - + public: EIGEN_PRODUCT_PUBLIC_INTERFACE(GeneralProduct) @@ -266,7 +266,7 @@ class GeneralProduct EIGEN_STATIC_ASSERT((internal::is_same::value), YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY) } - + struct set { template void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() = src; } }; struct add { template void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() += src; } }; struct sub { template void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() -= src; } }; @@ -277,12 +277,12 @@ class GeneralProduct dst.const_cast_derived() += m_scale * src; } }; - + template inline void evalTo(Dest& dest) const { internal::outer_product_selector_run(*this, dest, set(), IsRowMajor()); } - + template inline void addTo(Dest& dest) const { internal::outer_product_selector_run(*this, dest, add(), IsRowMajor()); @@ -436,12 +436,12 @@ template<> struct gemv_selector bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0)); bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible; - + RhsScalar compatibleAlpha = get_factor::run(actualAlpha); ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(), evalToDest ? dest.data() : static_dest.data()); - + if(!evalToDest) { #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN @@ -457,11 +457,13 @@ template<> struct gemv_selector MappedDest(actualDestPtr, dest.size()) = dest; } + typedef const_blas_data_mapper LhsMapper; + typedef const_blas_data_mapper RhsMapper; general_matrix_vector_product - ::run( + ::run( actualLhs.rows(), actualLhs.cols(), - actualLhs.data(), actualLhs.outerStride(), - actualRhs.data(), actualRhs.innerStride(), + LhsMapper(actualLhs.data(), actualLhs.outerStride()), + RhsMapper(actualRhs.data(), actualRhs.innerStride()), actualDestPtr, 1, compatibleAlpha); @@ -516,11 +518,13 @@ template<> struct gemv_selector Map(actualRhsPtr, actualRhs.size()) = actualRhs; } + typedef const_blas_data_mapper LhsMapper; + typedef const_blas_data_mapper RhsMapper; general_matrix_vector_product - ::run( + ::run( actualLhs.rows(), actualLhs.cols(), - actualLhs.data(), actualLhs.outerStride(), - actualRhsPtr, 1, + LhsMapper(actualLhs.data(), actualLhs.outerStride()), + RhsMapper(actualRhsPtr, 1), dest.data(), dest.innerStride(), actualAlpha); } @@ -594,7 +598,7 @@ MatrixBase::operator*(const MatrixBase &other) const #ifdef EIGEN_DEBUG_PRODUCT internal::product_type::debug(); #endif - + return Product(derived(), other.derived()); } #else -- cgit v1.2.3