diff options
Diffstat (limited to 'Eigen/src/Core/Product.h')
-rw-r--r-- | Eigen/src/Core/Product.h | 185 |
1 files changed, 127 insertions, 58 deletions
diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h index a645ab6de..d63a7aa95 100644 --- a/Eigen/src/Core/Product.h +++ b/Eigen/src/Core/Product.h @@ -73,24 +73,9 @@ struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct> typedef Product<LhsNested, RhsNested, CacheFriendlyProduct> Type; }; -/* Helper class to determine the type of the product, can be either: - * - NormalProduct - * - CacheFriendlyProduct - */ -template<typename Lhs, typename Rhs> struct ei_product_mode -{ - enum{ - - value = Lhs::MaxColsAtCompileTime == Dynamic - && ( Lhs::MaxRowsAtCompileTime == Dynamic - || Rhs::MaxColsAtCompileTime == Dynamic ) - && (!(Rhs::IsVectorAtCompileTime && (Lhs::Flags&RowMajorBit) && (!(Lhs::Flags&DirectAccessBit)))) - && (!(Lhs::IsVectorAtCompileTime && (!(Rhs::Flags&RowMajorBit)) && (!(Rhs::Flags&DirectAccessBit)))) - && (ei_is_same_type<typename Lhs::Scalar, typename Rhs::Scalar>::ret) - ? CacheFriendlyProduct - : NormalProduct }; -}; - +/* Helper class to analyze the factors of a Product expression. + * In particular it allows to pop out operator-, scalar multiples, + * and conjugate */ template<typename XprType> struct ei_product_factor_traits { typedef typename ei_traits<XprType>::Scalar Scalar; @@ -98,11 +83,10 @@ template<typename XprType> struct ei_product_factor_traits enum { IsComplex = NumTraits<Scalar>::IsComplex, NeedToConjugate = false, - HasScalarMultiple = false, - Access = int(ei_traits<XprType>::Flags)&DirectAccessBit ? HasDirectAccess : NoDirectAccess + ActualAccess = int(ei_traits<XprType>::Flags)&DirectAccessBit ? HasDirectAccess : NoDirectAccess }; static inline const ActualXprType& extract(const XprType& x) { return x; } - static inline Scalar extractSalarFactor(const XprType&) { return Scalar(1); } + static inline Scalar extractScalarFactor(const XprType&) { return Scalar(1); } }; // pop conjugate @@ -117,8 +101,8 @@ template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<Cw IsComplex = NumTraits<Scalar>::IsComplex, NeedToConjugate = IsComplex }; - static inline const ActualXprType& extract(const XprType& x) { return x._expression(); } - static inline Scalar extractSalarFactor(const XprType& x) { return Base::extractSalarFactor(x._expression()); } + static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); } + static inline Scalar extractScalarFactor(const XprType& x) { return ei_conj(Base::extractScalarFactor(x._expression())); } }; // pop scalar multiple @@ -128,11 +112,41 @@ template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<Cw typedef ei_product_factor_traits<NestedXpr> Base; typedef CwiseUnaryOp<ei_scalar_multiple_op<Scalar>, NestedXpr> XprType; typedef typename Base::ActualXprType ActualXprType; - enum { - HasScalarMultiple = true - }; - static inline const ActualXprType& extract(const XprType& x) { return x._expression(); } - static inline Scalar extractSalarFactor(const XprType& x) { return x._functor().m_other; } + static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); } + static inline Scalar extractScalarFactor(const XprType& x) + { return x._functor().m_other * Base::extractScalarFactor(x._expression()); } +}; + +// pop opposite +template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<CwiseUnaryOp<ei_scalar_opposite_op<Scalar>, NestedXpr> > + : ei_product_factor_traits<NestedXpr> +{ + typedef ei_product_factor_traits<NestedXpr> Base; + typedef CwiseUnaryOp<ei_scalar_opposite_op<Scalar>, NestedXpr> XprType; + typedef typename Base::ActualXprType ActualXprType; + static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); } + static inline Scalar extractScalarFactor(const XprType& x) + { return - Base::extractScalarFactor(x._expression()); } +}; + +/* Helper class to determine the type of the product, can be either: + * - NormalProduct + * - CacheFriendlyProduct + */ +template<typename Lhs, typename Rhs> struct ei_product_mode +{ + typedef typename ei_product_factor_traits<Lhs>::ActualXprType ActualLhs; + typedef typename ei_product_factor_traits<Rhs>::ActualXprType ActualRhs; + enum{ + + value = Lhs::MaxColsAtCompileTime == Dynamic + && ( Lhs::MaxRowsAtCompileTime == Dynamic + || Rhs::MaxColsAtCompileTime == Dynamic ) + && (!(Rhs::IsVectorAtCompileTime && (Lhs::Flags&RowMajorBit) && (!(ActualLhs::Flags&DirectAccessBit)))) + && (!(Lhs::IsVectorAtCompileTime && (!(Rhs::Flags&RowMajorBit)) && (!(ActualRhs::Flags&DirectAccessBit)))) + && (ei_is_same_type<typename Lhs::Scalar, typename Rhs::Scalar>::ret) + ? CacheFriendlyProduct + : NormalProduct }; }; /** \class Product @@ -552,11 +566,11 @@ void ei_cache_friendly_product( bool resRowMajor, Scalar* res, int resStride, Scalar alpha); -template<typename Scalar, typename RhsType> +template<bool ConjugateLhs, bool ConjugateRhs, typename Scalar, typename RhsType> static void ei_cache_friendly_product_colmajor_times_vector( int size, const Scalar* lhs, int lhsStride, const RhsType& rhs, Scalar* res, Scalar alpha); -template<typename Scalar, typename ResType> +template<bool ConjugateLhs, bool ConjugateRhs, typename Scalar, typename ResType> static void ei_cache_friendly_product_rowmajor_times_vector( const Scalar* lhs, int lhsStride, const Scalar* rhs, int rhsSize, ResType& res, Scalar alpha); @@ -572,10 +586,10 @@ static void ei_cache_friendly_product_rowmajor_times_vector( template<typename ProductType, int LhsRows = ei_traits<ProductType>::RowsAtCompileTime, int LhsOrder = int(ei_traits<ProductType>::LhsFlags)&RowMajorBit ? RowMajor : ColMajor, - int LhsHasDirectAccess = int(ei_traits<ProductType>::LhsFlags)&DirectAccessBit? HasDirectAccess : NoDirectAccess, + int LhsHasDirectAccess = ei_product_factor_traits<typename ei_traits<ProductType>::_LhsNested>::ActualAccess, int RhsCols = ei_traits<ProductType>::ColsAtCompileTime, int RhsOrder = int(ei_traits<ProductType>::RhsFlags)&RowMajorBit ? RowMajor : ColMajor, - int RhsHasDirectAccess = int(ei_traits<ProductType>::RhsFlags)&DirectAccessBit? HasDirectAccess : NoDirectAccess> + int RhsHasDirectAccess = ei_product_factor_traits<typename ei_traits<ProductType>::_RhsNested>::ActualAccess> struct ei_cache_friendly_product_selector { template<typename DestDerived> @@ -592,7 +606,6 @@ struct ei_cache_friendly_product_selector<ProductType,LhsRows,ColMajor,NoDirectA template<typename DestDerived> inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha) { - // FIXME is it really used ? ei_assert(alpha==typename ProductType::Scalar(1)); const int size = product.rhs().rows(); for (int k=0; k<size; ++k) @@ -606,10 +619,21 @@ template<typename ProductType, int LhsRows, int RhsOrder, int RhsAccess> struct ei_cache_friendly_product_selector<ProductType,LhsRows,ColMajor,HasDirectAccess,1,RhsOrder,RhsAccess> { typedef typename ProductType::Scalar Scalar; + typedef ei_product_factor_traits<typename ei_traits<ProductType>::_LhsNested> LhsProductTraits; + typedef ei_product_factor_traits<typename ei_traits<ProductType>::_RhsNested> RhsProductTraits; + typedef typename LhsProductTraits::ActualXprType ActualLhsType; + typedef typename RhsProductTraits::ActualXprType ActualRhsType; + template<typename DestDerived> inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha) { + const ActualLhsType& actualLhs = LhsProductTraits::extract(product.lhs()); + const ActualRhsType& actualRhs = RhsProductTraits::extract(product.rhs()); + + Scalar actualAlpha = alpha * LhsProductTraits::extractScalarFactor(product.lhs()) + * RhsProductTraits::extractScalarFactor(product.rhs()); + enum { EvalToRes = (ei_packet_traits<Scalar>::size==1) ||((DestDerived::Flags&ActualPacketAccessBit) && (!(DestDerived::Flags & RowMajorBit))) }; @@ -621,9 +645,12 @@ struct ei_cache_friendly_product_selector<ProductType,LhsRows,ColMajor,HasDirect _res = ei_aligned_stack_new(Scalar,res.size()); Map<Matrix<Scalar,DestDerived::RowsAtCompileTime,1> >(_res, res.size()) = res; } - ei_cache_friendly_product_colmajor_times_vector(res.size(), - &product.lhs().const_cast_derived().coeffRef(0,0), product.lhs().stride(), - product.rhs(), _res, alpha); + + ei_cache_friendly_product_colmajor_times_vector + <LhsProductTraits::NeedToConjugate,RhsProductTraits::NeedToConjugate>( + res.size(), + &actualLhs.const_cast_derived().coeffRef(0,0), actualLhs.stride(), + actualRhs, _res, actualAlpha); if (!EvalToRes) { @@ -653,10 +680,21 @@ template<typename ProductType, int LhsOrder, int LhsAccess, int RhsCols> struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCols,RowMajor,HasDirectAccess> { typedef typename ProductType::Scalar Scalar; + typedef ei_product_factor_traits<typename ei_traits<ProductType>::_LhsNested> LhsProductTraits; + typedef ei_product_factor_traits<typename ei_traits<ProductType>::_RhsNested> RhsProductTraits; + + typedef typename LhsProductTraits::ActualXprType ActualLhsType; + typedef typename RhsProductTraits::ActualXprType ActualRhsType; template<typename DestDerived> inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha) { + const ActualLhsType& actualLhs = LhsProductTraits::extract(product.lhs()); + const ActualRhsType& actualRhs = RhsProductTraits::extract(product.rhs()); + + Scalar actualAlpha = alpha * LhsProductTraits::extractScalarFactor(product.lhs()) + * RhsProductTraits::extractScalarFactor(product.rhs()); + enum { EvalToRes = (ei_packet_traits<Scalar>::size==1) ||((DestDerived::Flags & ActualPacketAccessBit) && (DestDerived::Flags & RowMajorBit)) }; @@ -668,9 +706,11 @@ struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCo _res = ei_aligned_stack_new(Scalar, res.size()); Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1> >(_res, res.size()) = res; } - ei_cache_friendly_product_colmajor_times_vector(res.size(), - &product.rhs().const_cast_derived().coeffRef(0,0), product.rhs().stride(), - product.lhs().transpose(), _res, alpha); + + ei_cache_friendly_product_colmajor_times_vector + <RhsProductTraits::NeedToConjugate,LhsProductTraits::NeedToConjugate>(res.size(), + &actualRhs.const_cast_derived().coeffRef(0,0), actualRhs.stride(), + actualLhs.transpose(), _res, actualAlpha); if (!EvalToRes) { @@ -685,24 +725,39 @@ template<typename ProductType, int LhsRows, int RhsOrder, int RhsAccess> struct ei_cache_friendly_product_selector<ProductType,LhsRows,RowMajor,HasDirectAccess,1,RhsOrder,RhsAccess> { typedef typename ProductType::Scalar Scalar; - typedef typename ei_traits<ProductType>::_RhsNested Rhs; + + typedef ei_product_factor_traits<typename ei_traits<ProductType>::_LhsNested> LhsProductTraits; + typedef ei_product_factor_traits<typename ei_traits<ProductType>::_RhsNested> RhsProductTraits; + + typedef typename LhsProductTraits::ActualXprType ActualLhsType; + typedef typename RhsProductTraits::ActualXprType ActualRhsType; + enum { - UseRhsDirectly = ((ei_packet_traits<Scalar>::size==1) || (Rhs::Flags&ActualPacketAccessBit)) - && (!(Rhs::Flags & RowMajorBit)) }; + UseRhsDirectly = ((ei_packet_traits<Scalar>::size==1) || (ActualRhsType::Flags&ActualPacketAccessBit)) + && (!(ActualRhsType::Flags & RowMajorBit)) }; template<typename DestDerived> inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha) { + const ActualLhsType& actualLhs = LhsProductTraits::extract(product.lhs()); + const ActualRhsType& actualRhs = RhsProductTraits::extract(product.rhs()); + + Scalar actualAlpha = alpha * LhsProductTraits::extractScalarFactor(product.lhs()) + * RhsProductTraits::extractScalarFactor(product.rhs()); + Scalar* EIGEN_RESTRICT _rhs; if (UseRhsDirectly) - _rhs = &product.rhs().const_cast_derived().coeffRef(0); + _rhs = &actualRhs.const_cast_derived().coeffRef(0); else { - _rhs = ei_aligned_stack_new(Scalar, product.rhs().size()); - Map<Matrix<Scalar,Rhs::SizeAtCompileTime,1> >(_rhs, product.rhs().size()) = product.rhs(); + _rhs = ei_aligned_stack_new(Scalar, actualRhs.size()); + Map<Matrix<Scalar,ActualRhsType::SizeAtCompileTime,1> >(_rhs, actualRhs.size()) = actualRhs; } - ei_cache_friendly_product_rowmajor_times_vector(&product.lhs().const_cast_derived().coeffRef(0,0), product.lhs().stride(), - _rhs, product.rhs().size(), res, alpha); + + ei_cache_friendly_product_rowmajor_times_vector + <LhsProductTraits::NeedToConjugate,RhsProductTraits::NeedToConjugate>( + &actualLhs.const_cast_derived().coeffRef(0,0), actualLhs.stride(), + _rhs, product.rhs().size(), res, actualAlpha); if (!UseRhsDirectly) ei_aligned_stack_delete(Scalar, _rhs, product.rhs().size()); } @@ -713,24 +768,39 @@ template<typename ProductType, int LhsOrder, int LhsAccess, int RhsCols> struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCols,ColMajor,HasDirectAccess> { typedef typename ProductType::Scalar Scalar; - typedef typename ei_traits<ProductType>::_LhsNested Lhs; + + typedef ei_product_factor_traits<typename ei_traits<ProductType>::_LhsNested> LhsProductTraits; + typedef ei_product_factor_traits<typename ei_traits<ProductType>::_RhsNested> RhsProductTraits; + + typedef typename LhsProductTraits::ActualXprType ActualLhsType; + typedef typename RhsProductTraits::ActualXprType ActualRhsType; + enum { - UseLhsDirectly = ((ei_packet_traits<Scalar>::size==1) || (Lhs::Flags&ActualPacketAccessBit)) - && (Lhs::Flags & RowMajorBit) }; + UseLhsDirectly = ((ei_packet_traits<Scalar>::size==1) || (ActualLhsType::Flags&ActualPacketAccessBit)) + && (ActualLhsType::Flags & RowMajorBit) }; template<typename DestDerived> inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha) { + const ActualLhsType& actualLhs = LhsProductTraits::extract(product.lhs()); + const ActualRhsType& actualRhs = RhsProductTraits::extract(product.rhs()); + + Scalar actualAlpha = alpha * LhsProductTraits::extractScalarFactor(product.lhs()) + * RhsProductTraits::extractScalarFactor(product.rhs()); + Scalar* EIGEN_RESTRICT _lhs; if (UseLhsDirectly) - _lhs = &product.lhs().const_cast_derived().coeffRef(0); + _lhs = &actualLhs.const_cast_derived().coeffRef(0); else { - _lhs = ei_aligned_stack_new(Scalar, product.lhs().size()); - Map<Matrix<Scalar,Lhs::SizeAtCompileTime,1> >(_lhs, product.lhs().size()) = product.lhs(); + _lhs = ei_aligned_stack_new(Scalar, actualLhs.size()); + Map<Matrix<Scalar,ActualLhsType::SizeAtCompileTime,1> >(_lhs, actualLhs.size()) = actualLhs; } - ei_cache_friendly_product_rowmajor_times_vector(&product.rhs().const_cast_derived().coeffRef(0,0), product.rhs().stride(), - _lhs, product.lhs().size(), res, alpha); + + ei_cache_friendly_product_rowmajor_times_vector + <RhsProductTraits::NeedToConjugate, LhsProductTraits::NeedToConjugate>( + &actualRhs.const_cast_derived().coeffRef(0,0), actualRhs.stride(), + _lhs, product.lhs().size(), res, actualAlpha); if(!UseLhsDirectly) ei_aligned_stack_delete(Scalar, _lhs, product.lhs().size()); } @@ -827,8 +897,8 @@ inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived& const ActualLhsType& actualLhs = LhsProductTraits::extract(m_lhs); const ActualRhsType& actualRhs = RhsProductTraits::extract(m_rhs); - Scalar actualAlpha = alpha * LhsProductTraits::extractSalarFactor(m_lhs) - * RhsProductTraits::extractSalarFactor(m_rhs); + Scalar actualAlpha = alpha * LhsProductTraits::extractScalarFactor(m_lhs) + * RhsProductTraits::extractScalarFactor(m_rhs); typedef typename ei_product_copy_lhs<ActualLhsType>::type LhsCopy; typedef typename ei_unref<LhsCopy>::type _LhsCopy; @@ -837,7 +907,6 @@ inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived& LhsCopy lhs(actualLhs); RhsCopy rhs(actualRhs); ei_cache_friendly_product<Scalar, -// LhsProductTraits::NeedToConjugate,RhsProductTraits::NeedToConjugate> ((int(Flags)&RowMajorBit) ? bool(RhsProductTraits::NeedToConjugate) : bool(LhsProductTraits::NeedToConjugate)), ((int(Flags)&RowMajorBit) ? bool(LhsProductTraits::NeedToConjugate) : bool(RhsProductTraits::NeedToConjugate))> ( |