diff options
author | Gael Guennebaud <g.gael@free.fr> | 2018-02-09 16:52:35 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2018-02-09 16:52:35 +0100 |
commit | 5deeb19e7bb19c67abeac0a6cfa26ad3d14e215b (patch) | |
tree | ef176c2771e7fa8e77e933fdb996b0e6c7ee71f0 /Eigen/src/Core/products/TriangularMatrixMatrix.h | |
parent | 12efc7d41b80259b996be5781bf596c249c90d3f (diff) |
bug #1517: fix triangular product with unit diagonal and nested scaling factor: (s*A).triangularView<UpperUnit>()*B
Diffstat (limited to 'Eigen/src/Core/products/TriangularMatrixMatrix.h')
-rw-r--r-- | Eigen/src/Core/products/TriangularMatrixMatrix.h | 24 |
1 files changed, 21 insertions, 3 deletions
diff --git a/Eigen/src/Core/products/TriangularMatrixMatrix.h b/Eigen/src/Core/products/TriangularMatrixMatrix.h index 539b6c0c6..f784507e7 100644 --- a/Eigen/src/Core/products/TriangularMatrixMatrix.h +++ b/Eigen/src/Core/products/TriangularMatrixMatrix.h @@ -400,7 +400,9 @@ struct triangular_product_impl<Mode,LhsIsTriangular,Lhs,false,Rhs,false> { template<typename Dest> static void run(Dest& dst, const Lhs &a_lhs, const Rhs &a_rhs, const typename Dest::Scalar& alpha) { - typedef typename Dest::Scalar Scalar; + typedef typename Lhs::Scalar LhsScalar; + typedef typename Rhs::Scalar RhsScalar; + typedef typename Dest::Scalar Scalar; typedef internal::blas_traits<Lhs> LhsBlasTraits; typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; @@ -412,8 +414,9 @@ struct triangular_product_impl<Mode,LhsIsTriangular,Lhs,false,Rhs,false> typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs); typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs); - Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs) - * RhsBlasTraits::extractScalarFactor(a_rhs); + LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(a_lhs); + RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(a_rhs); + Scalar actualAlpha = alpha * lhs_alpha * rhs_alpha; typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,Scalar,Scalar, Lhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime, Lhs::MaxColsAtCompileTime,4> BlockingType; @@ -438,6 +441,21 @@ struct triangular_product_impl<Mode,LhsIsTriangular,Lhs,false,Rhs,false> &dst.coeffRef(0,0), dst.outerStride(), // result info actualAlpha, blocking ); + + // Apply correction if the diagonal is unit and a scalar factor was nested: + if ((Mode&UnitDiag)==UnitDiag) + { + if (LhsIsTriangular && lhs_alpha!=LhsScalar(1)) + { + Index diagSize = (std::min)(lhs.rows(),lhs.cols()); + dst.topRows(diagSize) -= ((lhs_alpha-LhsScalar(1))*a_rhs).topRows(diagSize); + } + else if ((!LhsIsTriangular) && rhs_alpha!=RhsScalar(1)) + { + Index diagSize = (std::min)(rhs.rows(),rhs.cols()); + dst.leftCols(diagSize) -= (rhs_alpha-RhsScalar(1))*a_lhs.leftCols(diagSize); + } + } } }; |