From 04367447ac295d9818713f54b0a539efef7f0caa Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Sun, 24 Feb 2013 23:05:42 +0100 Subject: Fix bug #496: generalize internal rank1_update implementation to accept uplo(A) += v * w and make A.triangularView() += v * w uses it. Update unit tests and blas interface respectively. --- .../Core/products/GeneralMatrixMatrixTriangular.h | 107 +++++++++++++++++---- Eigen/src/Core/products/SelfadjointProduct.h | 14 ++- 2 files changed, 92 insertions(+), 29 deletions(-) (limited to 'Eigen/src/Core') diff --git a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h index 432d3a9dc..c4f83cd13 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h @@ -12,6 +12,9 @@ namespace Eigen { +template +struct selfadjoint_rank1_update; + namespace internal { /********************************************************************** @@ -180,31 +183,93 @@ struct tribb_kernel // high level API +template +struct general_product_to_triangular_selector; + + +template +struct general_product_to_triangular_selector +{ + static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha) + { + typedef typename MatrixType::Scalar Scalar; + typedef typename MatrixType::Index Index; + + typedef typename internal::remove_all::type Lhs; + typedef internal::blas_traits LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs; + typedef typename internal::remove_all::type _ActualLhs; + typename internal::add_const_on_value_type::type actualLhs = LhsBlasTraits::extract(prod.lhs()); + + typedef typename internal::remove_all::type Rhs; + typedef internal::blas_traits RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs; + typedef typename internal::remove_all::type _ActualRhs; + typename internal::add_const_on_value_type::type actualRhs = RhsBlasTraits::extract(prod.rhs()); + + Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived()); + + enum { + StorageOrder = (internal::traits::Flags&RowMajorBit) ? RowMajor : ColMajor, + UseLhsDirectly = _ActualLhs::InnerStrideAtCompileTime==1, + UseRhsDirectly = _ActualRhs::InnerStrideAtCompileTime==1 + }; + + internal::gemv_static_vector_if static_lhs; + ei_declare_aligned_stack_constructed_variable(Scalar, actualLhsPtr, actualLhs.size(), + (UseLhsDirectly ? const_cast(actualLhs.data()) : static_lhs.data())); + if(!UseLhsDirectly) Map(actualLhsPtr, actualLhs.size()) = actualLhs; + + internal::gemv_static_vector_if static_rhs; + ei_declare_aligned_stack_constructed_variable(Scalar, actualRhsPtr, actualRhs.size(), + (UseRhsDirectly ? const_cast(actualRhs.data()) : static_rhs.data())); + if(!UseRhsDirectly) Map(actualRhsPtr, actualRhs.size()) = actualRhs; + + + selfadjoint_rank1_update::IsComplex, + RhsBlasTraits::NeedToConjugate && NumTraits::IsComplex> + ::run(actualLhs.size(), mat.data(), mat.outerStride(), actualLhsPtr, actualRhsPtr, actualAlpha); + } +}; + +template +struct general_product_to_triangular_selector +{ + static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha) + { + typedef typename MatrixType::Scalar Scalar; + typedef typename MatrixType::Index Index; + + typedef typename internal::remove_all::type Lhs; + typedef internal::blas_traits LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs; + typedef typename internal::remove_all::type _ActualLhs; + typename internal::add_const_on_value_type::type actualLhs = LhsBlasTraits::extract(prod.lhs()); + + typedef typename internal::remove_all::type Rhs; + typedef internal::blas_traits RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs; + typedef typename internal::remove_all::type _ActualRhs; + typename internal::add_const_on_value_type::type actualRhs = RhsBlasTraits::extract(prod.rhs()); + + typename ProductType::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived()); + + internal::general_matrix_matrix_triangular_product + ::run(mat.cols(), actualLhs.cols(), + &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &actualRhs.coeffRef(0,0), actualRhs.outerStride(), + mat.data(), mat.outerStride(), actualAlpha); + } +}; + template template TriangularView& TriangularView::assignProduct(const ProductBase& prod, const Scalar& alpha) { - typedef typename internal::remove_all::type Lhs; - typedef internal::blas_traits LhsBlasTraits; - typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs; - typedef typename internal::remove_all::type _ActualLhs; - typename internal::add_const_on_value_type::type actualLhs = LhsBlasTraits::extract(prod.lhs()); - - typedef typename internal::remove_all::type Rhs; - typedef internal::blas_traits RhsBlasTraits; - typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs; - typedef typename internal::remove_all::type _ActualRhs; - typename internal::add_const_on_value_type::type actualRhs = RhsBlasTraits::extract(prod.rhs()); - - typename ProductDerived::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived()); - - internal::general_matrix_matrix_triangular_product - ::run(m_matrix.cols(), actualLhs.cols(), - &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &actualRhs.coeffRef(0,0), actualRhs.outerStride(), - const_cast(m_matrix.data()), m_matrix.outerStride(), actualAlpha); + general_product_to_triangular_selector::run(m_matrix.const_cast_derived(), prod.derived(), alpha); return *this; } diff --git a/Eigen/src/Core/products/SelfadjointProduct.h b/Eigen/src/Core/products/SelfadjointProduct.h index 6a55f3d77..302e0d841 100644 --- a/Eigen/src/Core/products/SelfadjointProduct.h +++ b/Eigen/src/Core/products/SelfadjointProduct.h @@ -18,21 +18,19 @@ namespace Eigen { -template -struct selfadjoint_rank1_update; template struct selfadjoint_rank1_update { - static void run(Index size, Scalar* mat, Index stride, const Scalar* vec, Scalar alpha) + static void run(Index size, Scalar* mat, Index stride, const Scalar* vecX, const Scalar* vecY, Scalar alpha) { internal::conj_if cj; typedef Map > OtherMap; - typedef typename internal::conditional::type ConjRhsType; + typedef typename internal::conditional::type ConjLhsType; for (Index i=0; i >(mat+stride*i+(UpLo==Lower ? i : 0), (UpLo==Lower ? size-i : (i+1))) - += (alpha * cj(vec[i])) * ConjRhsType(OtherMap(vec+(UpLo==Lower ? i : 0),UpLo==Lower ? size-i : (i+1))); + += (alpha * cj(vecY[i])) * ConjLhsType(OtherMap(vecX+(UpLo==Lower ? i : 0),UpLo==Lower ? size-i : (i+1))); } } }; @@ -40,9 +38,9 @@ struct selfadjoint_rank1_update template struct selfadjoint_rank1_update { - static void run(Index size, Scalar* mat, Index stride, const Scalar* vec, Scalar alpha) + static void run(Index size, Scalar* mat, Index stride, const Scalar* vecX, const Scalar* vecY, Scalar alpha) { - selfadjoint_rank1_update::run(size,mat,stride,vec,alpha); + selfadjoint_rank1_update::run(size,mat,stride,vecY,vecX,alpha); } }; @@ -78,7 +76,7 @@ struct selfadjoint_product_selector selfadjoint_rank1_update::IsComplex, (!OtherBlasTraits::NeedToConjugate) && NumTraits::IsComplex> - ::run(other.size(), mat.data(), mat.outerStride(), actualOtherPtr, actualAlpha); + ::run(other.size(), mat.data(), mat.outerStride(), actualOtherPtr, actualOtherPtr, actualAlpha); } }; -- cgit v1.2.3