diff options
Diffstat (limited to 'Eigen/src')
-rw-r--r-- | Eigen/src/Core/ProductBase.h | 2 | ||||
-rw-r--r-- | Eigen/src/Core/products/TriangularMatrixVector.h | 186 |
2 files changed, 163 insertions, 25 deletions
diff --git a/Eigen/src/Core/ProductBase.h b/Eigen/src/Core/ProductBase.h index 1f2b373cd..287ea554f 100644 --- a/Eigen/src/Core/ProductBase.h +++ b/Eigen/src/Core/ProductBase.h @@ -84,12 +84,14 @@ class ProductBase : public MatrixBase<Derived> typedef internal::blas_traits<_LhsNested> LhsBlasTraits; typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; typedef typename internal::remove_all<ActualLhsType>::type _ActualLhsType; + typedef typename internal::traits<Lhs>::Scalar LhsScalar; typedef typename Rhs::Nested RhsNested; typedef typename internal::remove_all<RhsNested>::type _RhsNested; typedef internal::blas_traits<_RhsNested> RhsBlasTraits; typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; typedef typename internal::remove_all<ActualRhsType>::type _ActualRhsType; + typedef typename internal::traits<Rhs>::Scalar RhsScalar; // Diagonal of a product: no need to evaluate the arguments because they are going to be evaluated only once typedef CoeffBasedProduct<LhsNested, RhsNested, 0> FullyLazyCoeffBaseProductType; diff --git a/Eigen/src/Core/products/TriangularMatrixVector.h b/Eigen/src/Core/products/TriangularMatrixVector.h index c1f64dcea..23aa52ade 100644 --- a/Eigen/src/Core/products/TriangularMatrixVector.h +++ b/Eigen/src/Core/products/TriangularMatrixVector.h @@ -152,6 +152,10 @@ struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> > : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> > {}; + +template<int StorageOrder> +struct trmv_selector; + } // end namespace internal template<int Mode, typename Lhs, typename Rhs> @@ -165,20 +169,8 @@ struct TriangularProduct<Mode,true,Lhs,false,Rhs,true> template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const { eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); - - const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs); - const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs); - - Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs) - * RhsBlasTraits::extractScalarFactor(m_rhs); - - internal::product_triangular_matrix_vector - <Index,Mode, - typename _ActualLhsType::Scalar, LhsBlasTraits::NeedToConjugate, - typename _ActualRhsType::Scalar, RhsBlasTraits::NeedToConjugate, - (int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor> - ::run(lhs.rows(),lhs.cols(),lhs.data(),lhs.outerStride(),rhs.data(),rhs.innerStride(), - dst.data(),dst.innerStride(),actualAlpha); + + internal::trmv_selector<(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dst, alpha); } }; @@ -192,23 +184,167 @@ struct TriangularProduct<Mode,false,Lhs,true,Rhs,false> template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const { - eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); + + typedef TriangularProduct<(Mode & UnitDiag) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose; + Transpose<Dest> dstT(dst); + internal::trmv_selector<(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>::run( + TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha); + } +}; + +namespace internal { + +// TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same. + +template<> struct trmv_selector<ColMajor> +{ + template<int Mode, typename Lhs, typename Rhs, typename Dest> + static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha) + { + typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType; + typedef typename ProductType::Index Index; + typedef typename ProductType::LhsScalar LhsScalar; + typedef typename ProductType::RhsScalar RhsScalar; + typedef typename ProductType::Scalar ResScalar; + typedef typename ProductType::RealScalar RealScalar; + typedef typename ProductType::ActualLhsType ActualLhsType; + typedef typename ProductType::ActualRhsType ActualRhsType; + typedef typename ProductType::LhsBlasTraits LhsBlasTraits; + typedef typename ProductType::RhsBlasTraits RhsBlasTraits; + typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest; + + const ActualLhsType actualLhs = LhsBlasTraits::extract(prod.lhs()); + const ActualRhsType actualRhs = RhsBlasTraits::extract(prod.rhs()); + + ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs()) + * RhsBlasTraits::extractScalarFactor(prod.rhs()); - const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs); - const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs); + enum { + // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1 + // on, the other hand it is good for the cache to pack the vector anyways... + EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1, + ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex), + MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal + }; - Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs) - * RhsBlasTraits::extractScalarFactor(m_rhs); + gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest; + bool alphaIsCompatible = (!ComplexByReal) || (imag(actualAlpha)==RealScalar(0)); + bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible; + + RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha); + + ResScalar* actualDestPtr; + bool freeDestPtr = false; + if (evalToDest) + { + actualDestPtr = dest.data(); + } + else + { + #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN + int size = dest.size(); + EIGEN_DENSE_STORAGE_CTOR_PLUGIN + #endif + if((actualDestPtr = static_dest.data())==0) + { + freeDestPtr = true; + actualDestPtr = ei_aligned_stack_new(ResScalar,dest.size()); + } + if(!alphaIsCompatible) + { + MappedDest(actualDestPtr, dest.size()).setZero(); + compatibleAlpha = RhsScalar(1); + } + else + MappedDest(actualDestPtr, dest.size()) = dest; + } + internal::product_triangular_matrix_vector - <Index,(Mode & UnitDiag) | ((Mode & Lower) ? Upper : Lower), - typename _ActualRhsType::Scalar, RhsBlasTraits::NeedToConjugate, - typename _ActualLhsType::Scalar, LhsBlasTraits::NeedToConjugate, - (int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor> - ::run(rhs.rows(),rhs.cols(),rhs.data(),rhs.outerStride(),lhs.data(),lhs.innerStride(), - dst.data(),dst.innerStride(),actualAlpha); + <Index,Mode, + LhsScalar, LhsBlasTraits::NeedToConjugate, + RhsScalar, RhsBlasTraits::NeedToConjugate, + ColMajor> + ::run(actualLhs.rows(),actualLhs.cols(), + actualLhs.data(),actualLhs.outerStride(), + actualRhs.data(),actualRhs.innerStride(), + actualDestPtr,1,compatibleAlpha); + + if (!evalToDest) + { + if(!alphaIsCompatible) + dest += actualAlpha * MappedDest(actualDestPtr, dest.size()); + else + dest = MappedDest(actualDestPtr, dest.size()); + if(freeDestPtr) ei_aligned_stack_delete(ResScalar, actualDestPtr, dest.size()); + } } }; +template<> struct trmv_selector<RowMajor> +{ + template<int Mode, typename Lhs, typename Rhs, typename Dest> + static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha) + { + typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType; + typedef typename ProductType::LhsScalar LhsScalar; + typedef typename ProductType::RhsScalar RhsScalar; + typedef typename ProductType::Scalar ResScalar; + typedef typename ProductType::Index Index; + typedef typename ProductType::ActualLhsType ActualLhsType; + typedef typename ProductType::ActualRhsType ActualRhsType; + typedef typename ProductType::_ActualRhsType _ActualRhsType; + typedef typename ProductType::LhsBlasTraits LhsBlasTraits; + typedef typename ProductType::RhsBlasTraits RhsBlasTraits; + + typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs()); + typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs()); + + ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs()) + * RhsBlasTraits::extractScalarFactor(prod.rhs()); + + enum { + DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1 + }; + + gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs; + + RhsScalar* actualRhsPtr; + bool freeRhsPtr = false; + if (DirectlyUseRhs) + { + actualRhsPtr = const_cast<RhsScalar*>(actualRhs.data()); + } + else + { + #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN + int size = actualRhs.size(); + EIGEN_DENSE_STORAGE_CTOR_PLUGIN + #endif + if((actualRhsPtr = static_rhs.data())==0) + { + freeRhsPtr = true; + actualRhsPtr = ei_aligned_stack_new(RhsScalar, actualRhs.size()); + } + Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs; + } + + internal::product_triangular_matrix_vector + <Index,Mode, + LhsScalar, LhsBlasTraits::NeedToConjugate, + RhsScalar, RhsBlasTraits::NeedToConjugate, + RowMajor> + ::run(actualLhs.rows(),actualLhs.cols(), + actualLhs.data(),actualLhs.outerStride(), + actualRhsPtr,1, + dest.data(),dest.innerStride(), + actualAlpha); + + if((!DirectlyUseRhs) && freeRhsPtr) ei_aligned_stack_delete(RhsScalar, actualRhsPtr, prod.rhs().size()); + } +}; + +} // end namespace internal + #endif // EIGEN_TRIANGULARMATRIXVECTOR_H |