From 04367447ac295d9818713f54b0a539efef7f0caa Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Sun, 24 Feb 2013 23:05:42 +0100 Subject: Fix bug #496: generalize internal rank1_update implementation to accept uplo(A) += v * w and make A.triangularView() += v * w uses it. Update unit tests and blas interface respectively. --- .../Core/products/GeneralMatrixMatrixTriangular.h | 107 +++++++++++++++++---- Eigen/src/Core/products/SelfadjointProduct.h | 14 ++- blas/level2_cplx_impl.h | 4 +- blas/level2_real_impl.h | 4 +- test/product_syrk.cpp | 46 ++++++++- 5 files changed, 138 insertions(+), 37 deletions(-) diff --git a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h index 432d3a9dc..c4f83cd13 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h @@ -12,6 +12,9 @@ namespace Eigen { +template +struct selfadjoint_rank1_update; + namespace internal { /********************************************************************** @@ -180,31 +183,93 @@ struct tribb_kernel // high level API +template +struct general_product_to_triangular_selector; + + +template +struct general_product_to_triangular_selector +{ + static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha) + { + typedef typename MatrixType::Scalar Scalar; + typedef typename MatrixType::Index Index; + + typedef typename internal::remove_all::type Lhs; + typedef internal::blas_traits LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs; + typedef typename internal::remove_all::type _ActualLhs; + typename internal::add_const_on_value_type::type actualLhs = LhsBlasTraits::extract(prod.lhs()); + + typedef typename internal::remove_all::type Rhs; + typedef internal::blas_traits RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs; + typedef typename internal::remove_all::type _ActualRhs; + typename internal::add_const_on_value_type::type actualRhs = RhsBlasTraits::extract(prod.rhs()); + + Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived()); + + enum { + StorageOrder = (internal::traits::Flags&RowMajorBit) ? RowMajor : ColMajor, + UseLhsDirectly = _ActualLhs::InnerStrideAtCompileTime==1, + UseRhsDirectly = _ActualRhs::InnerStrideAtCompileTime==1 + }; + + internal::gemv_static_vector_if static_lhs; + ei_declare_aligned_stack_constructed_variable(Scalar, actualLhsPtr, actualLhs.size(), + (UseLhsDirectly ? const_cast(actualLhs.data()) : static_lhs.data())); + if(!UseLhsDirectly) Map(actualLhsPtr, actualLhs.size()) = actualLhs; + + internal::gemv_static_vector_if static_rhs; + ei_declare_aligned_stack_constructed_variable(Scalar, actualRhsPtr, actualRhs.size(), + (UseRhsDirectly ? const_cast(actualRhs.data()) : static_rhs.data())); + if(!UseRhsDirectly) Map(actualRhsPtr, actualRhs.size()) = actualRhs; + + + selfadjoint_rank1_update::IsComplex, + RhsBlasTraits::NeedToConjugate && NumTraits::IsComplex> + ::run(actualLhs.size(), mat.data(), mat.outerStride(), actualLhsPtr, actualRhsPtr, actualAlpha); + } +}; + +template +struct general_product_to_triangular_selector +{ + static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha) + { + typedef typename MatrixType::Scalar Scalar; + typedef typename MatrixType::Index Index; + + typedef typename internal::remove_all::type Lhs; + typedef internal::blas_traits LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs; + typedef typename internal::remove_all::type _ActualLhs; + typename internal::add_const_on_value_type::type actualLhs = LhsBlasTraits::extract(prod.lhs()); + + typedef typename internal::remove_all::type Rhs; + typedef internal::blas_traits RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs; + typedef typename internal::remove_all::type _ActualRhs; + typename internal::add_const_on_value_type::type actualRhs = RhsBlasTraits::extract(prod.rhs()); + + typename ProductType::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived()); + + internal::general_matrix_matrix_triangular_product + ::run(mat.cols(), actualLhs.cols(), + &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &actualRhs.coeffRef(0,0), actualRhs.outerStride(), + mat.data(), mat.outerStride(), actualAlpha); + } +}; + template template TriangularView& TriangularView::assignProduct(const ProductBase& prod, const Scalar& alpha) { - typedef typename internal::remove_all::type Lhs; - typedef internal::blas_traits LhsBlasTraits; - typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs; - typedef typename internal::remove_all::type _ActualLhs; - typename internal::add_const_on_value_type::type actualLhs = LhsBlasTraits::extract(prod.lhs()); - - typedef typename internal::remove_all::type Rhs; - typedef internal::blas_traits RhsBlasTraits; - typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs; - typedef typename internal::remove_all::type _ActualRhs; - typename internal::add_const_on_value_type::type actualRhs = RhsBlasTraits::extract(prod.rhs()); - - typename ProductDerived::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived()); - - internal::general_matrix_matrix_triangular_product - ::run(m_matrix.cols(), actualLhs.cols(), - &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &actualRhs.coeffRef(0,0), actualRhs.outerStride(), - const_cast(m_matrix.data()), m_matrix.outerStride(), actualAlpha); + general_product_to_triangular_selector::run(m_matrix.const_cast_derived(), prod.derived(), alpha); return *this; } diff --git a/Eigen/src/Core/products/SelfadjointProduct.h b/Eigen/src/Core/products/SelfadjointProduct.h index 6a55f3d77..302e0d841 100644 --- a/Eigen/src/Core/products/SelfadjointProduct.h +++ b/Eigen/src/Core/products/SelfadjointProduct.h @@ -18,21 +18,19 @@ namespace Eigen { -template -struct selfadjoint_rank1_update; template struct selfadjoint_rank1_update { - static void run(Index size, Scalar* mat, Index stride, const Scalar* vec, Scalar alpha) + static void run(Index size, Scalar* mat, Index stride, const Scalar* vecX, const Scalar* vecY, Scalar alpha) { internal::conj_if cj; typedef Map > OtherMap; - typedef typename internal::conditional::type ConjRhsType; + typedef typename internal::conditional::type ConjLhsType; for (Index i=0; i >(mat+stride*i+(UpLo==Lower ? i : 0), (UpLo==Lower ? size-i : (i+1))) - += (alpha * cj(vec[i])) * ConjRhsType(OtherMap(vec+(UpLo==Lower ? i : 0),UpLo==Lower ? size-i : (i+1))); + += (alpha * cj(vecY[i])) * ConjLhsType(OtherMap(vecX+(UpLo==Lower ? i : 0),UpLo==Lower ? size-i : (i+1))); } } }; @@ -40,9 +38,9 @@ struct selfadjoint_rank1_update template struct selfadjoint_rank1_update { - static void run(Index size, Scalar* mat, Index stride, const Scalar* vec, Scalar alpha) + static void run(Index size, Scalar* mat, Index stride, const Scalar* vecX, const Scalar* vecY, Scalar alpha) { - selfadjoint_rank1_update::run(size,mat,stride,vec,alpha); + selfadjoint_rank1_update::run(size,mat,stride,vecY,vecX,alpha); } }; @@ -78,7 +76,7 @@ struct selfadjoint_product_selector selfadjoint_rank1_update::IsComplex, (!OtherBlasTraits::NeedToConjugate) && NumTraits::IsComplex> - ::run(other.size(), mat.data(), mat.outerStride(), actualOtherPtr, actualAlpha); + ::run(other.size(), mat.data(), mat.outerStride(), actualOtherPtr, actualOtherPtr, actualAlpha); } }; diff --git a/blas/level2_cplx_impl.h b/blas/level2_cplx_impl.h index f52d384a9..ceed3e86d 100644 --- a/blas/level2_cplx_impl.h +++ b/blas/level2_cplx_impl.h @@ -216,7 +216,7 @@ int EIGEN_BLAS_FUNC(hpr2)(char *uplo, int *n, RealScalar *palpha, RealScalar *px */ int EIGEN_BLAS_FUNC(her)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, int *incx, RealScalar *pa, int *lda) { - typedef void (*functype)(int, Scalar*, int, const Scalar*, Scalar); + typedef void (*functype)(int, Scalar*, int, const Scalar*, const Scalar*, Scalar); static functype func[2]; static bool init = false; @@ -252,7 +252,7 @@ int EIGEN_BLAS_FUNC(her)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, if(code>=2 || func[code]==0) return 0; - func[code](*n, a, *lda, x_cpy, alpha); + func[code](*n, a, *lda, x_cpy, x_cpy, alpha); matrix(a,*n,*n,*lda).diagonal().imag().setZero(); diff --git a/blas/level2_real_impl.h b/blas/level2_real_impl.h index febf08d1f..842f0a066 100644 --- a/blas/level2_real_impl.h +++ b/blas/level2_real_impl.h @@ -85,7 +85,7 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, // init = true; // } - typedef void (*functype)(int, Scalar*, int, const Scalar*, Scalar); + typedef void (*functype)(int, Scalar*, int, const Scalar*, const Scalar*, Scalar); static functype func[2]; static bool init = false; @@ -121,7 +121,7 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, if(code>=2 || func[code]==0) return 0; - func[code](*n, c, *ldc, x_cpy, alpha); + func[code](*n, c, *ldc, x_cpy, x_cpy, alpha); if(x_cpy!=x) delete[] x_cpy; diff --git a/test/product_syrk.cpp b/test/product_syrk.cpp index 5855c2181..ad233af70 100644 --- a/test/product_syrk.cpp +++ b/test/product_syrk.cpp @@ -14,6 +14,7 @@ template void syrk(const MatrixType& m) typedef typename MatrixType::Index Index; typedef typename MatrixType::Scalar Scalar; typedef typename NumTraits::Real RealScalar; + typedef Matrix RMatrixType; typedef Matrix Rhs1; typedef Matrix Rhs2; typedef Matrix Rhs3; @@ -22,10 +23,12 @@ template void syrk(const MatrixType& m) Index cols = m.cols(); MatrixType m1 = MatrixType::Random(rows, cols), - m2 = MatrixType::Random(rows, cols); + m2 = MatrixType::Random(rows, cols), + m3 = MatrixType::Random(rows, cols); + RMatrixType rm2 = MatrixType::Random(rows, cols); - Rhs1 rhs1 = Rhs1::Random(internal::random(1,320), cols); - Rhs2 rhs2 = Rhs2::Random(rows, internal::random(1,320)); + Rhs1 rhs1 = Rhs1::Random(internal::random(1,320), cols); Rhs1 rhs11 = Rhs1::Random(rhs1.rows(), cols); + Rhs2 rhs2 = Rhs2::Random(rows, internal::random(1,320)); Rhs2 rhs22 = Rhs2::Random(rows, rhs2.cols()); Rhs3 rhs3 = Rhs3::Random(internal::random(1,320), rows); Scalar s1 = internal::random(); @@ -35,19 +38,34 @@ template void syrk(const MatrixType& m) m2.setZero(); VERIFY_IS_APPROX((m2.template selfadjointView().rankUpdate(rhs2,s1)._expression()), ((s1 * rhs2 * rhs2.adjoint()).eval().template triangularView().toDenseMatrix())); + m2.setZero(); + VERIFY_IS_APPROX(((m2.template triangularView() += s1 * rhs2 * rhs22.adjoint()).nestedExpression()), + ((s1 * rhs2 * rhs22.adjoint()).eval().template triangularView().toDenseMatrix())); + m2.setZero(); VERIFY_IS_APPROX(m2.template selfadjointView().rankUpdate(rhs2,s1)._expression(), (s1 * rhs2 * rhs2.adjoint()).eval().template triangularView().toDenseMatrix()); + m2.setZero(); + VERIFY_IS_APPROX((m2.template triangularView() += s1 * rhs22 * rhs2.adjoint()).nestedExpression(), + (s1 * rhs22 * rhs2.adjoint()).eval().template triangularView().toDenseMatrix()); + m2.setZero(); VERIFY_IS_APPROX(m2.template selfadjointView().rankUpdate(rhs1.adjoint(),s1)._expression(), (s1 * rhs1.adjoint() * rhs1).eval().template triangularView().toDenseMatrix()); - + m2.setZero(); + VERIFY_IS_APPROX((m2.template triangularView() += s1 * rhs11.adjoint() * rhs1).nestedExpression(), + (s1 * rhs11.adjoint() * rhs1).eval().template triangularView().toDenseMatrix()); + + m2.setZero(); VERIFY_IS_APPROX(m2.template selfadjointView().rankUpdate(rhs1.adjoint(),s1)._expression(), (s1 * rhs1.adjoint() * rhs1).eval().template triangularView().toDenseMatrix()); + VERIFY_IS_APPROX((m2.template triangularView() = s1 * rhs1.adjoint() * rhs11).nestedExpression(), + (s1 * rhs1.adjoint() * rhs11).eval().template triangularView().toDenseMatrix()); + m2.setZero(); VERIFY_IS_APPROX(m2.template selfadjointView().rankUpdate(rhs3.adjoint(),s1)._expression(), (s1 * rhs3.adjoint() * rhs3).eval().template triangularView().toDenseMatrix()); @@ -63,6 +81,15 @@ template void syrk(const MatrixType& m) m2.setZero(); VERIFY_IS_APPROX((m2.template selfadjointView().rankUpdate(m1.col(c),s1)._expression()), ((s1 * m1.col(c) * m1.col(c).adjoint()).eval().template triangularView().toDenseMatrix())); + rm2.setZero(); + VERIFY_IS_APPROX((rm2.template selfadjointView().rankUpdate(m1.col(c),s1)._expression()), + ((s1 * m1.col(c) * m1.col(c).adjoint()).eval().template triangularView().toDenseMatrix())); + m2.setZero(); + VERIFY_IS_APPROX((m2.template triangularView() += s1 * m3.col(c) * m1.col(c).adjoint()).nestedExpression(), + ((s1 * m3.col(c) * m1.col(c).adjoint()).eval().template triangularView().toDenseMatrix())); + rm2.setZero(); + VERIFY_IS_APPROX((rm2.template triangularView() += s1 * m1.col(c) * m3.col(c).adjoint()).nestedExpression(), + ((s1 * m1.col(c) * m3.col(c).adjoint()).eval().template triangularView().toDenseMatrix())); m2.setZero(); VERIFY_IS_APPROX((m2.template selfadjointView().rankUpdate(m1.col(c).conjugate(),s1)._expression()), @@ -72,9 +99,20 @@ template void syrk(const MatrixType& m) VERIFY_IS_APPROX((m2.template selfadjointView().rankUpdate(m1.col(c).conjugate(),s1)._expression()), ((s1 * m1.col(c).conjugate() * m1.col(c).conjugate().adjoint()).eval().template triangularView().toDenseMatrix())); + m2.setZero(); VERIFY_IS_APPROX((m2.template selfadjointView().rankUpdate(m1.row(c),s1)._expression()), ((s1 * m1.row(c).transpose() * m1.row(c).transpose().adjoint()).eval().template triangularView().toDenseMatrix())); + rm2.setZero(); + VERIFY_IS_APPROX((rm2.template selfadjointView().rankUpdate(m1.row(c),s1)._expression()), + ((s1 * m1.row(c).transpose() * m1.row(c).transpose().adjoint()).eval().template triangularView().toDenseMatrix())); + m2.setZero(); + VERIFY_IS_APPROX((m2.template triangularView() += s1 * m3.row(c).transpose() * m1.row(c).transpose().adjoint()).nestedExpression(), + ((s1 * m3.row(c).transpose() * m1.row(c).transpose().adjoint()).eval().template triangularView().toDenseMatrix())); + rm2.setZero(); + VERIFY_IS_APPROX((rm2.template triangularView() += s1 * m3.row(c).transpose() * m1.row(c).transpose().adjoint()).nestedExpression(), + ((s1 * m3.row(c).transpose() * m1.row(c).transpose().adjoint()).eval().template triangularView().toDenseMatrix())); + m2.setZero(); VERIFY_IS_APPROX((m2.template selfadjointView().rankUpdate(m1.row(c).adjoint(),s1)._expression()), -- cgit v1.2.3