diff options
Diffstat (limited to 'Eigen/src/Core')
-rw-r--r-- | Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h | 107 | ||||
-rw-r--r-- | Eigen/src/Core/products/SelfadjointProduct.h | 14 |
2 files changed, 92 insertions, 29 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h index 432d3a9dc..c4f83cd13 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h @@ -12,6 +12,9 @@ namespace Eigen { +template<typename Scalar, typename Index, int StorageOrder, int UpLo, bool ConjLhs, bool ConjRhs> +struct selfadjoint_rank1_update; + namespace internal { /********************************************************************** @@ -180,31 +183,93 @@ struct tribb_kernel // high level API +template<typename MatrixType, typename ProductType, int UpLo, bool IsOuterProduct> +struct general_product_to_triangular_selector; + + +template<typename MatrixType, typename ProductType, int UpLo> +struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,true> +{ + static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha) + { + typedef typename MatrixType::Scalar Scalar; + typedef typename MatrixType::Index Index; + + typedef typename internal::remove_all<typename ProductType::LhsNested>::type Lhs; + typedef internal::blas_traits<Lhs> LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs; + typedef typename internal::remove_all<ActualLhs>::type _ActualLhs; + typename internal::add_const_on_value_type<ActualLhs>::type actualLhs = LhsBlasTraits::extract(prod.lhs()); + + typedef typename internal::remove_all<typename ProductType::RhsNested>::type Rhs; + typedef internal::blas_traits<Rhs> RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs; + typedef typename internal::remove_all<ActualRhs>::type _ActualRhs; + typename internal::add_const_on_value_type<ActualRhs>::type actualRhs = RhsBlasTraits::extract(prod.rhs()); + + Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived()); + + enum { + StorageOrder = (internal::traits<MatrixType>::Flags&RowMajorBit) ? RowMajor : ColMajor, + UseLhsDirectly = _ActualLhs::InnerStrideAtCompileTime==1, + UseRhsDirectly = _ActualRhs::InnerStrideAtCompileTime==1 + }; + + internal::gemv_static_vector_if<Scalar,Lhs::SizeAtCompileTime,Lhs::MaxSizeAtCompileTime,!UseLhsDirectly> static_lhs; + ei_declare_aligned_stack_constructed_variable(Scalar, actualLhsPtr, actualLhs.size(), + (UseLhsDirectly ? const_cast<Scalar*>(actualLhs.data()) : static_lhs.data())); + if(!UseLhsDirectly) Map<typename _ActualLhs::PlainObject>(actualLhsPtr, actualLhs.size()) = actualLhs; + + internal::gemv_static_vector_if<Scalar,Rhs::SizeAtCompileTime,Rhs::MaxSizeAtCompileTime,!UseRhsDirectly> static_rhs; + ei_declare_aligned_stack_constructed_variable(Scalar, actualRhsPtr, actualRhs.size(), + (UseRhsDirectly ? const_cast<Scalar*>(actualRhs.data()) : static_rhs.data())); + if(!UseRhsDirectly) Map<typename _ActualRhs::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs; + + + selfadjoint_rank1_update<Scalar,Index,StorageOrder,UpLo, + LhsBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex, + RhsBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex> + ::run(actualLhs.size(), mat.data(), mat.outerStride(), actualLhsPtr, actualRhsPtr, actualAlpha); + } +}; + +template<typename MatrixType, typename ProductType, int UpLo> +struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,false> +{ + static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha) + { + typedef typename MatrixType::Scalar Scalar; + typedef typename MatrixType::Index Index; + + typedef typename internal::remove_all<typename ProductType::LhsNested>::type Lhs; + typedef internal::blas_traits<Lhs> LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs; + typedef typename internal::remove_all<ActualLhs>::type _ActualLhs; + typename internal::add_const_on_value_type<ActualLhs>::type actualLhs = LhsBlasTraits::extract(prod.lhs()); + + typedef typename internal::remove_all<typename ProductType::RhsNested>::type Rhs; + typedef internal::blas_traits<Rhs> RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs; + typedef typename internal::remove_all<ActualRhs>::type _ActualRhs; + typename internal::add_const_on_value_type<ActualRhs>::type actualRhs = RhsBlasTraits::extract(prod.rhs()); + + typename ProductType::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived()); + + internal::general_matrix_matrix_triangular_product<Index, + typename Lhs::Scalar, _ActualLhs::Flags&RowMajorBit ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate, + typename Rhs::Scalar, _ActualRhs::Flags&RowMajorBit ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate, + MatrixType::Flags&RowMajorBit ? RowMajor : ColMajor, UpLo> + ::run(mat.cols(), actualLhs.cols(), + &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &actualRhs.coeffRef(0,0), actualRhs.outerStride(), + mat.data(), mat.outerStride(), actualAlpha); + } +}; + template<typename MatrixType, unsigned int UpLo> template<typename ProductDerived, typename _Lhs, typename _Rhs> TriangularView<MatrixType,UpLo>& TriangularView<MatrixType,UpLo>::assignProduct(const ProductBase<ProductDerived, _Lhs,_Rhs>& prod, const Scalar& alpha) { - typedef typename internal::remove_all<typename ProductDerived::LhsNested>::type Lhs; - typedef internal::blas_traits<Lhs> LhsBlasTraits; - typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs; - typedef typename internal::remove_all<ActualLhs>::type _ActualLhs; - typename internal::add_const_on_value_type<ActualLhs>::type actualLhs = LhsBlasTraits::extract(prod.lhs()); - - typedef typename internal::remove_all<typename ProductDerived::RhsNested>::type Rhs; - typedef internal::blas_traits<Rhs> RhsBlasTraits; - typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs; - typedef typename internal::remove_all<ActualRhs>::type _ActualRhs; - typename internal::add_const_on_value_type<ActualRhs>::type actualRhs = RhsBlasTraits::extract(prod.rhs()); - - typename ProductDerived::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived()); - - internal::general_matrix_matrix_triangular_product<Index, - typename Lhs::Scalar, _ActualLhs::Flags&RowMajorBit ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate, - typename Rhs::Scalar, _ActualRhs::Flags&RowMajorBit ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate, - MatrixType::Flags&RowMajorBit ? RowMajor : ColMajor, UpLo> - ::run(m_matrix.cols(), actualLhs.cols(), - &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &actualRhs.coeffRef(0,0), actualRhs.outerStride(), - const_cast<Scalar*>(m_matrix.data()), m_matrix.outerStride(), actualAlpha); + general_product_to_triangular_selector<MatrixType, ProductDerived, UpLo, (_Lhs::ColsAtCompileTime==1) || (_Rhs::RowsAtCompileTime==1)>::run(m_matrix.const_cast_derived(), prod.derived(), alpha); return *this; } diff --git a/Eigen/src/Core/products/SelfadjointProduct.h b/Eigen/src/Core/products/SelfadjointProduct.h index 6a55f3d77..302e0d841 100644 --- a/Eigen/src/Core/products/SelfadjointProduct.h +++ b/Eigen/src/Core/products/SelfadjointProduct.h @@ -18,21 +18,19 @@ namespace Eigen { -template<typename Scalar, typename Index, int StorageOrder, int UpLo, bool ConjLhs, bool ConjRhs> -struct selfadjoint_rank1_update; template<typename Scalar, typename Index, int UpLo, bool ConjLhs, bool ConjRhs> struct selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo,ConjLhs,ConjRhs> { - static void run(Index size, Scalar* mat, Index stride, const Scalar* vec, Scalar alpha) + static void run(Index size, Scalar* mat, Index stride, const Scalar* vecX, const Scalar* vecY, Scalar alpha) { internal::conj_if<ConjRhs> cj; typedef Map<const Matrix<Scalar,Dynamic,1> > OtherMap; - typedef typename internal::conditional<ConjLhs,typename OtherMap::ConjugateReturnType,const OtherMap&>::type ConjRhsType; + typedef typename internal::conditional<ConjLhs,typename OtherMap::ConjugateReturnType,const OtherMap&>::type ConjLhsType; for (Index i=0; i<size; ++i) { Map<Matrix<Scalar,Dynamic,1> >(mat+stride*i+(UpLo==Lower ? i : 0), (UpLo==Lower ? size-i : (i+1))) - += (alpha * cj(vec[i])) * ConjRhsType(OtherMap(vec+(UpLo==Lower ? i : 0),UpLo==Lower ? size-i : (i+1))); + += (alpha * cj(vecY[i])) * ConjLhsType(OtherMap(vecX+(UpLo==Lower ? i : 0),UpLo==Lower ? size-i : (i+1))); } } }; @@ -40,9 +38,9 @@ struct selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo,ConjLhs,ConjRhs> template<typename Scalar, typename Index, int UpLo, bool ConjLhs, bool ConjRhs> struct selfadjoint_rank1_update<Scalar,Index,RowMajor,UpLo,ConjLhs,ConjRhs> { - static void run(Index size, Scalar* mat, Index stride, const Scalar* vec, Scalar alpha) + static void run(Index size, Scalar* mat, Index stride, const Scalar* vecX, const Scalar* vecY, Scalar alpha) { - selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo==Lower?Upper:Lower,ConjRhs,ConjLhs>::run(size,mat,stride,vec,alpha); + selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo==Lower?Upper:Lower,ConjRhs,ConjLhs>::run(size,mat,stride,vecY,vecX,alpha); } }; @@ -78,7 +76,7 @@ struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,true> selfadjoint_rank1_update<Scalar,Index,StorageOrder,UpLo, OtherBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex, (!OtherBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex> - ::run(other.size(), mat.data(), mat.outerStride(), actualOtherPtr, actualAlpha); + ::run(other.size(), mat.data(), mat.outerStride(), actualOtherPtr, actualOtherPtr, actualAlpha); } }; |