diff options
author | Gael Guennebaud <g.gael@free.fr> | 2014-02-21 17:13:28 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2014-02-21 17:13:28 +0100 |
commit | d67548f345d01c69d9dbba5869d8cc0159e96464 (patch) | |
tree | 07188e26fe6758f5311a857ba5597719a6c57581 /Eigen/src/Core/GeneralProduct.h | |
parent | 6c7ab508117d84671054808c921980a4908efb20 (diff) |
Get rid of GeneralProduct<> for GemvProduct
Diffstat (limited to 'Eigen/src/Core/GeneralProduct.h')
-rw-r--r-- | Eigen/src/Core/GeneralProduct.h | 208 |
1 files changed, 193 insertions, 15 deletions
diff --git a/Eigen/src/Core/GeneralProduct.h b/Eigen/src/Core/GeneralProduct.h index 4c0fc7f63..57d5d3c38 100644 --- a/Eigen/src/Core/GeneralProduct.h +++ b/Eigen/src/Core/GeneralProduct.h @@ -342,16 +342,19 @@ class GeneralProduct<Lhs, Rhs, OuterProduct> */ namespace internal { +#ifndef EIGEN_TEST_EVALUATORS template<typename Lhs, typename Rhs> struct traits<GeneralProduct<Lhs,Rhs,GemvProduct> > : traits<ProductBase<GeneralProduct<Lhs,Rhs,GemvProduct>, Lhs, Rhs> > {}; +#endif template<int Side, int StorageOrder, bool BlasCompatible> struct gemv_selector; } // end namespace internal +#ifndef EIGEN_TEST_EVALUATORS template<typename Lhs, typename Rhs> class GeneralProduct<Lhs, Rhs, GemvProduct> : public ProductBase<GeneralProduct<Lhs,Rhs,GemvProduct>, Lhs, Rhs> @@ -378,24 +381,10 @@ class GeneralProduct<Lhs, Rhs, GemvProduct> bool(internal::blas_traits<MatrixType>::HasUsableDirectAccess)>::run(*this, dst, alpha); } }; +#endif namespace internal { -// The vector is on the left => transposition -template<int StorageOrder, bool BlasCompatible> -struct gemv_selector<OnTheLeft,StorageOrder,BlasCompatible> -{ - template<typename ProductType, typename Dest> - static void run(const ProductType& prod, Dest& dest, const typename ProductType::Scalar& alpha) - { - Transpose<Dest> destT(dest); - enum { OtherStorageOrder = StorageOrder == RowMajor ? ColMajor : RowMajor }; - gemv_selector<OnTheRight,OtherStorageOrder,BlasCompatible> - ::run(GeneralProduct<Transpose<const typename ProductType::_RhsNested>,Transpose<const typename ProductType::_LhsNested>, GemvProduct> - (prod.rhs().transpose(), prod.lhs().transpose()), destT, alpha); - } -}; - template<typename Scalar,int Size,int MaxSize,bool Cond> struct gemv_static_vector_if; template<typename Scalar,int Size,int MaxSize> @@ -432,6 +421,23 @@ struct gemv_static_vector_if<Scalar,Size,MaxSize,true> #endif }; +#ifndef EIGEN_TEST_EVALUATORS + +// The vector is on the left => transposition +template<int StorageOrder, bool BlasCompatible> +struct gemv_selector<OnTheLeft,StorageOrder,BlasCompatible> +{ + template<typename ProductType, typename Dest> + static void run(const ProductType& prod, Dest& dest, const typename ProductType::Scalar& alpha) + { + Transpose<Dest> destT(dest); + enum { OtherStorageOrder = StorageOrder == RowMajor ? ColMajor : RowMajor }; + gemv_selector<OnTheRight,OtherStorageOrder,BlasCompatible> + ::run(GeneralProduct<Transpose<const typename ProductType::_RhsNested>,Transpose<const typename ProductType::_LhsNested>, GemvProduct> + (prod.rhs().transpose(), prod.lhs().transpose()), destT, alpha); + } +}; + template<> struct gemv_selector<OnTheRight,ColMajor,true> { template<typename ProductType, typename Dest> @@ -582,6 +588,178 @@ template<> struct gemv_selector<OnTheRight,RowMajor,false> } }; +#else // EIGEN_TEST_EVALUATORS + +// The vector is on the left => transposition +template<int StorageOrder, bool BlasCompatible> +struct gemv_selector<OnTheLeft,StorageOrder,BlasCompatible> +{ + template<typename Lhs, typename Rhs, typename Dest> + static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha) + { + Transpose<Dest> destT(dest); + enum { OtherStorageOrder = StorageOrder == RowMajor ? ColMajor : RowMajor }; + gemv_selector<OnTheRight,OtherStorageOrder,BlasCompatible> + ::run(rhs.transpose(), lhs.transpose(), destT, alpha); + } +}; + +template<> struct gemv_selector<OnTheRight,ColMajor,true> +{ + template<typename Lhs, typename Rhs, typename Dest> + static inline void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha) + { + typedef typename Dest::Index Index; + typedef typename Lhs::Scalar LhsScalar; + typedef typename Rhs::Scalar RhsScalar; + typedef typename Dest::Scalar ResScalar; + typedef typename Dest::RealScalar RealScalar; + + typedef internal::blas_traits<Lhs> LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; + typedef internal::blas_traits<Rhs> RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; + + typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest; + + ActualLhsType actualLhs = LhsBlasTraits::extract(lhs); + ActualRhsType actualRhs = RhsBlasTraits::extract(rhs); + + ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs) + * RhsBlasTraits::extractScalarFactor(rhs); + + enum { + // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1 + // on, the other hand it is good for the cache to pack the vector anyways... + EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1, + ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex), + MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal + }; + + gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest; + + bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0)); + bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible; + + RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha); + + ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(), + evalToDest ? dest.data() : static_dest.data()); + + if(!evalToDest) + { + #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN + int size = dest.size(); + EIGEN_DENSE_STORAGE_CTOR_PLUGIN + #endif + if(!alphaIsCompatible) + { + MappedDest(actualDestPtr, dest.size()).setZero(); + compatibleAlpha = RhsScalar(1); + } + else + MappedDest(actualDestPtr, dest.size()) = dest; + } + + general_matrix_vector_product + <Index,LhsScalar,ColMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsBlasTraits::NeedToConjugate>::run( + actualLhs.rows(), actualLhs.cols(), + actualLhs.data(), actualLhs.outerStride(), + actualRhs.data(), actualRhs.innerStride(), + actualDestPtr, 1, + compatibleAlpha); + + if (!evalToDest) + { + if(!alphaIsCompatible) + dest += actualAlpha * MappedDest(actualDestPtr, dest.size()); + else + dest = MappedDest(actualDestPtr, dest.size()); + } + } +}; + +template<> struct gemv_selector<OnTheRight,RowMajor,true> +{ + template<typename Lhs, typename Rhs, typename Dest> + static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha) + { + typedef typename Dest::Index Index; + typedef typename Lhs::Scalar LhsScalar; + typedef typename Rhs::Scalar RhsScalar; + typedef typename Dest::Scalar ResScalar; + typedef typename Dest::RealScalar RealScalar; + + typedef internal::blas_traits<Lhs> LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; + typedef internal::blas_traits<Rhs> RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; + typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned; + + typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs); + typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs); + + ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs) + * RhsBlasTraits::extractScalarFactor(rhs); + + enum { + // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1 + // on, the other hand it is good for the cache to pack the vector anyways... + DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1 + }; + + gemv_static_vector_if<RhsScalar,ActualRhsTypeCleaned::SizeAtCompileTime,ActualRhsTypeCleaned::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs; + + ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(), + DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data()); + + if(!DirectlyUseRhs) + { + #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN + int size = actualRhs.size(); + EIGEN_DENSE_STORAGE_CTOR_PLUGIN + #endif + Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs; + } + + general_matrix_vector_product + <Index,LhsScalar,RowMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsBlasTraits::NeedToConjugate>::run( + actualLhs.rows(), actualLhs.cols(), + actualLhs.data(), actualLhs.outerStride(), + actualRhsPtr, 1, + dest.data(), dest.innerStride(), + actualAlpha); + } +}; + +template<> struct gemv_selector<OnTheRight,ColMajor,false> +{ + template<typename Lhs, typename Rhs, typename Dest> + static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha) + { + typedef typename Dest::Index Index; + // TODO makes sure dest is sequentially stored in memory, otherwise use a temp + const Index size = rhs.rows(); + for(Index k=0; k<size; ++k) + dest += (alpha*rhs.coeff(k)) * lhs.col(k); + } +}; + +template<> struct gemv_selector<OnTheRight,RowMajor,false> +{ + template<typename Lhs, typename Rhs, typename Dest> + static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha) + { + typedef typename Dest::Index Index; + // TODO makes sure rhs is sequentially stored in memory, otherwise use a temp + const Index rows = dest.rows(); + for(Index i=0; i<rows; ++i) + dest.coeffRef(i) += alpha * (lhs.row(i).cwiseProduct(rhs.transpose())).sum(); + } +}; + +#endif // EIGEN_TEST_EVALUATORS + } // end namespace internal /*************************************************************************** |