diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2014-10-31 16:33:51 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2014-10-31 16:33:51 -0700 |
commit | 2dde63499c4ef836a0d9dfd443494d863ad62b16 (patch) | |
tree | 572f24e62fb03318adb527eaefcdb4593cb31b50 /Eigen/src/Core/GeneralProduct.h | |
parent | 7f2c6ed2fa35d7f83f0da83c8564b7bd5b01d232 (diff) |
Generalized the matrix vector product code.
Diffstat (limited to 'Eigen/src/Core/GeneralProduct.h')
-rw-r--r-- | Eigen/src/Core/GeneralProduct.h | 32 |
1 files changed, 18 insertions, 14 deletions
diff --git a/Eigen/src/Core/GeneralProduct.h b/Eigen/src/Core/GeneralProduct.h index 7179eb124..9d3d5562c 100644 --- a/Eigen/src/Core/GeneralProduct.h +++ b/Eigen/src/Core/GeneralProduct.h @@ -11,7 +11,7 @@ #ifndef EIGEN_GENERAL_PRODUCT_H #define EIGEN_GENERAL_PRODUCT_H -namespace Eigen { +namespace Eigen { /** \class GeneralProduct * \ingroup Core_Module @@ -257,7 +257,7 @@ class GeneralProduct<Lhs, Rhs, OuterProduct> : public ProductBase<GeneralProduct<Lhs,Rhs,OuterProduct>, Lhs, Rhs> { template<typename T> struct IsRowMajor : internal::conditional<(int(T::Flags)&RowMajorBit), internal::true_type, internal::false_type>::type {}; - + public: EIGEN_PRODUCT_PUBLIC_INTERFACE(GeneralProduct) @@ -266,7 +266,7 @@ class GeneralProduct<Lhs, Rhs, OuterProduct> EIGEN_STATIC_ASSERT((internal::is_same<typename Lhs::RealScalar, typename Rhs::RealScalar>::value), YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY) } - + struct set { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() = src; } }; struct add { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() += src; } }; struct sub { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() -= src; } }; @@ -277,12 +277,12 @@ class GeneralProduct<Lhs, Rhs, OuterProduct> dst.const_cast_derived() += m_scale * src; } }; - + template<typename Dest> inline void evalTo(Dest& dest) const { internal::outer_product_selector_run(*this, dest, set(), IsRowMajor<Dest>()); } - + template<typename Dest> inline void addTo(Dest& dest) const { internal::outer_product_selector_run(*this, dest, add(), IsRowMajor<Dest>()); @@ -436,12 +436,12 @@ template<> struct gemv_selector<OnTheRight,ColMajor,true> bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0)); bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible; - + RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha); ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(), evalToDest ? dest.data() : static_dest.data()); - + if(!evalToDest) { #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN @@ -457,11 +457,13 @@ template<> struct gemv_selector<OnTheRight,ColMajor,true> MappedDest(actualDestPtr, dest.size()) = dest; } + typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper; + typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper; general_matrix_vector_product - <Index,LhsScalar,ColMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsBlasTraits::NeedToConjugate>::run( + <Index,LhsScalar,LhsMapper,ColMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsMapper,RhsBlasTraits::NeedToConjugate>::run( actualLhs.rows(), actualLhs.cols(), - actualLhs.data(), actualLhs.outerStride(), - actualRhs.data(), actualRhs.innerStride(), + LhsMapper(actualLhs.data(), actualLhs.outerStride()), + RhsMapper(actualRhs.data(), actualRhs.innerStride()), actualDestPtr, 1, compatibleAlpha); @@ -516,11 +518,13 @@ template<> struct gemv_selector<OnTheRight,RowMajor,true> Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs; } + typedef const_blas_data_mapper<LhsScalar,Index,RowMajor> LhsMapper; + typedef const_blas_data_mapper<RhsScalar,Index,ColMajor> RhsMapper; general_matrix_vector_product - <Index,LhsScalar,RowMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsBlasTraits::NeedToConjugate>::run( + <Index,LhsScalar,LhsMapper,RowMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsMapper,RhsBlasTraits::NeedToConjugate>::run( actualLhs.rows(), actualLhs.cols(), - actualLhs.data(), actualLhs.outerStride(), - actualRhsPtr, 1, + LhsMapper(actualLhs.data(), actualLhs.outerStride()), + RhsMapper(actualRhsPtr, 1), dest.data(), dest.innerStride(), actualAlpha); } @@ -594,7 +598,7 @@ MatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const #ifdef EIGEN_DEBUG_PRODUCT internal::product_type<Derived,OtherDerived>::debug(); #endif - + return Product<Derived, OtherDerived>(derived(), other.derived()); } #else |