diff options
author | Gael Guennebaud <g.gael@free.fr> | 2010-07-11 15:48:30 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2010-07-11 15:48:30 +0200 |
commit | ff96c94043d575e4d0dd477c1ed2487e33f79627 (patch) | |
tree | 5e9736916779fdacd431c2591a3ec1f77333e505 /Eigen/src/Core/Product.h | |
parent | 4161b8be6772f2b7338458c9932d7417797966bb (diff) |
mixing types in product step 2:
* pload* and pset1 are now templated on the packet type
* gemv routines are now embeded into a structure with
a consistent API with respect to gemm
* some configurations of vector * matrix and matrix * matrix works fine,
some need more work...
Diffstat (limited to 'Eigen/src/Core/Product.h')
-rw-r--r-- | Eigen/src/Core/Product.h | 63 |
1 files changed, 37 insertions, 26 deletions
diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h index 139132c6b..edd79bd9a 100644 --- a/Eigen/src/Core/Product.h +++ b/Eigen/src/Core/Product.h @@ -282,10 +282,13 @@ class GeneralProduct<Lhs, Rhs, GemvProduct> public: EIGEN_PRODUCT_PUBLIC_INTERFACE(GeneralProduct) + typedef typename Lhs::Scalar LhsScalar; + typedef typename Rhs::Scalar RhsScalar; + GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) { - EIGEN_STATIC_ASSERT((ei_is_same_type<typename Lhs::Scalar, typename Rhs::Scalar>::ret), - YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY) +// EIGEN_STATIC_ASSERT((ei_is_same_type<typename Lhs::Scalar, typename Rhs::Scalar>::ret), +// YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY) } enum { Side = Lhs::IsVectorAtCompileTime ? OnTheLeft : OnTheRight }; @@ -295,7 +298,8 @@ class GeneralProduct<Lhs, Rhs, GemvProduct> { ei_assert(m_lhs.rows() == dst.rows() && m_rhs.cols() == dst.cols()); ei_gemv_selector<Side,(int(MatrixType::Flags)&RowMajorBit) ? RowMajor : ColMajor, - bool(ei_blas_traits<MatrixType>::HasUsableDirectAccess)>::run(*this, dst, alpha); + bool(ei_blas_traits<MatrixType>::HasUsableDirectAccess) + /*&& ei_is_same_type<typename Lhs::Scalar, typename Rhs::Scalar>::ret*/>::run(*this, dst, alpha); } }; @@ -319,43 +323,48 @@ template<> struct ei_gemv_selector<OnTheRight,ColMajor,true> template<typename ProductType, typename Dest> static void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha) { - typedef typename ProductType::Scalar Scalar; + typedef typename ProductType::Index Index; + typedef typename ProductType::LhsScalar LhsScalar; + typedef typename ProductType::RhsScalar RhsScalar; + typedef typename ProductType::Scalar ResScalar; typedef typename ProductType::ActualLhsType ActualLhsType; typedef typename ProductType::ActualRhsType ActualRhsType; typedef typename ProductType::LhsBlasTraits LhsBlasTraits; typedef typename ProductType::RhsBlasTraits RhsBlasTraits; - typedef Map<Matrix<Scalar,Dynamic,1>, Aligned> MappedDest; + typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest; ActualLhsType actualLhs = LhsBlasTraits::extract(prod.lhs()); ActualRhsType actualRhs = RhsBlasTraits::extract(prod.rhs()); - Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs()) - * RhsBlasTraits::extractScalarFactor(prod.rhs()); + ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs()) + * RhsBlasTraits::extractScalarFactor(prod.rhs()); enum { // FIXME find a way to allow an inner stride on the result if ei_packet_traits<Scalar>::size==1 EvalToDest = Dest::InnerStrideAtCompileTime==1 }; - Scalar* EIGEN_RESTRICT actualDest; + ResScalar* actualDest; if (EvalToDest) actualDest = &dest.coeffRef(0); else { - actualDest = ei_aligned_stack_new(Scalar,dest.size()); + actualDest = ei_aligned_stack_new(ResScalar,dest.size()); MappedDest(actualDest, dest.size()) = dest; } - ei_cache_friendly_product_colmajor_times_vector - <LhsBlasTraits::NeedToConjugate,RhsBlasTraits::NeedToConjugate>( - dest.size(), - &actualLhs.const_cast_derived().coeffRef(0,0), actualLhs.outerStride(), - actualRhs, actualDest, actualAlpha); + ei_general_matrix_vector_product + <Index,LhsScalar,ColMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsBlasTraits::NeedToConjugate>::run( + actualLhs.rows(), actualLhs.cols(), + &actualLhs.const_cast_derived().coeffRef(0,0), actualLhs.outerStride(), + actualRhs, actualRhs.innerStride(), + actualDest, 1, + actualAlpha); if (!EvalToDest) { dest = MappedDest(actualDest, dest.size()); - ei_aligned_stack_delete(Scalar, actualDest, dest.size()); + ei_aligned_stack_delete(ResScalar, actualDest, dest.size()); } } }; @@ -365,7 +374,9 @@ template<> struct ei_gemv_selector<OnTheRight,RowMajor,true> template<typename ProductType, typename Dest> static void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha) { - typedef typename ProductType::Scalar Scalar; + typedef typename ProductType::LhsScalar LhsScalar; + typedef typename ProductType::RhsScalar RhsScalar; + typedef typename ProductType::Scalar ResScalar; typedef typename ProductType::Index Index; typedef typename ProductType::ActualLhsType ActualLhsType; typedef typename ProductType::ActualRhsType ActualRhsType; @@ -376,34 +387,34 @@ template<> struct ei_gemv_selector<OnTheRight,RowMajor,true> ActualLhsType actualLhs = LhsBlasTraits::extract(prod.lhs()); ActualRhsType actualRhs = RhsBlasTraits::extract(prod.rhs()); - Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs()) - * RhsBlasTraits::extractScalarFactor(prod.rhs()); + ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs()) + * RhsBlasTraits::extractScalarFactor(prod.rhs()); enum { // FIXME I think here we really have to check for ei_packet_traits<Scalar>::size==1 // because in this case it is fine to have an inner stride - DirectlyUseRhs = ((ei_packet_traits<Scalar>::size==1) || (_ActualRhsType::Flags&ActualPacketAccessBit)) + DirectlyUseRhs = ((ei_packet_traits<RhsScalar>::size==1) || (_ActualRhsType::Flags&ActualPacketAccessBit)) && (!(_ActualRhsType::Flags & RowMajorBit)) }; - Scalar* EIGEN_RESTRICT rhs_data; + RhsScalar* rhs_data; if (DirectlyUseRhs) - rhs_data = reinterpret_cast<Scalar* EIGEN_RESTRICT>(&actualRhs.const_cast_derived().coeffRef(0)); + rhs_data = &actualRhs.const_cast_derived().coeffRef(0); else { - rhs_data = ei_aligned_stack_new(Scalar, actualRhs.size()); - Map<typename _ActualRhsType::PlainObject>(reinterpret_cast<Scalar*>(rhs_data), actualRhs.size()) = actualRhs; + rhs_data = ei_aligned_stack_new(RhsScalar, actualRhs.size()); + Map<typename _ActualRhsType::PlainObject>(rhs_data, actualRhs.size()) = actualRhs; } - ei_cache_friendly_product_rowmajor_times_vector - <LhsBlasTraits::NeedToConjugate,RhsBlasTraits::NeedToConjugate, Scalar, Index>( + ei_general_matrix_vector_product + <Index,LhsScalar,RowMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsBlasTraits::NeedToConjugate>::run( actualLhs.rows(), actualLhs.cols(), &actualLhs.const_cast_derived().coeffRef(0,0), actualLhs.outerStride(), rhs_data, 1, &dest.coeffRef(0,0), dest.innerStride(), actualAlpha); - if (!DirectlyUseRhs) ei_aligned_stack_delete(Scalar, rhs_data, prod.rhs().size()); + if (!DirectlyUseRhs) ei_aligned_stack_delete(RhsScalar, rhs_data, prod.rhs().size()); } }; |