aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2017-01-17 18:03:35 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2017-01-17 18:03:35 +0100
commit655ba783f8b2c9a8c3f4edb45e6db468aca22188 (patch)
treec71924c27b1b4746acd108e254e721c51f376085
parentc4fc2611ba34652f98b5e0ac9f817879bef8eed1 (diff)
Defer set-to-zero in triangular = product so that no aliasing issue occur in the common:
A.triangularView() = B*A.sefladjointView()*B.adjoint() case that used to work in 3.2.
-rw-r--r--Eigen/src/Core/TriangularMatrix.h9
-rw-r--r--Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h14
-rw-r--r--test/product_mmtr.cpp13
3 files changed, 27 insertions, 9 deletions
diff --git a/Eigen/src/Core/TriangularMatrix.h b/Eigen/src/Core/TriangularMatrix.h
index c35fb8083..667ef09dc 100644
--- a/Eigen/src/Core/TriangularMatrix.h
+++ b/Eigen/src/Core/TriangularMatrix.h
@@ -543,7 +543,7 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularViewImpl<_Mat
template<typename ProductType>
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE TriangularViewType& _assignProduct(const ProductType& prod, const Scalar& alpha);
+ EIGEN_STRONG_INLINE TriangularViewType& _assignProduct(const ProductType& prod, const Scalar& alpha, bool beta);
};
/***************************************************************************
@@ -950,8 +950,7 @@ struct Assignment<DstXprType, Product<Lhs,Rhs,DefaultProduct>, internal::assign_
if((dst.rows()!=dstRows) || (dst.cols()!=dstCols))
dst.resize(dstRows, dstCols);
- dst.setZero();
- dst._assignProduct(src, 1);
+ dst._assignProduct(src, 1, 0);
}
};
@@ -962,7 +961,7 @@ struct Assignment<DstXprType, Product<Lhs,Rhs,DefaultProduct>, internal::add_ass
typedef Product<Lhs,Rhs,DefaultProduct> SrcXprType;
static void run(DstXprType &dst, const SrcXprType &src, const internal::add_assign_op<Scalar,typename SrcXprType::Scalar> &)
{
- dst._assignProduct(src, 1);
+ dst._assignProduct(src, 1, 1);
}
};
@@ -973,7 +972,7 @@ struct Assignment<DstXprType, Product<Lhs,Rhs,DefaultProduct>, internal::sub_ass
typedef Product<Lhs,Rhs,DefaultProduct> SrcXprType;
static void run(DstXprType &dst, const SrcXprType &src, const internal::sub_assign_op<Scalar,typename SrcXprType::Scalar> &)
{
- dst._assignProduct(src, -1);
+ dst._assignProduct(src, -1, 1);
}
};
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h
index 29d6dc721..5cd2794a4 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h
@@ -199,7 +199,7 @@ struct general_product_to_triangular_selector;
template<typename MatrixType, typename ProductType, int UpLo>
struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,true>
{
- static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha)
+ static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha, bool beta)
{
typedef typename MatrixType::Scalar Scalar;
@@ -217,6 +217,9 @@ struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,true>
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived());
+ if(!beta)
+ mat.template triangularView<UpLo>().setZero();
+
enum {
StorageOrder = (internal::traits<MatrixType>::Flags&RowMajorBit) ? RowMajor : ColMajor,
UseLhsDirectly = _ActualLhs::InnerStrideAtCompileTime==1,
@@ -244,7 +247,7 @@ struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,true>
template<typename MatrixType, typename ProductType, int UpLo>
struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,false>
{
- static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha)
+ static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha, bool beta)
{
typedef typename internal::remove_all<typename ProductType::LhsNested>::type Lhs;
typedef internal::blas_traits<Lhs> LhsBlasTraits;
@@ -260,6 +263,9 @@ struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,false>
typename ProductType::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived());
+ if(!beta)
+ mat.template triangularView<UpLo>().setZero();
+
enum {
IsRowMajor = (internal::traits<MatrixType>::Flags&RowMajorBit) ? 1 : 0,
LhsIsRowMajor = _ActualLhs::Flags&RowMajorBit ? 1 : 0,
@@ -286,11 +292,11 @@ struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,false>
template<typename MatrixType, unsigned int UpLo>
template<typename ProductType>
-TriangularView<MatrixType,UpLo>& TriangularViewImpl<MatrixType,UpLo,Dense>::_assignProduct(const ProductType& prod, const Scalar& alpha)
+TriangularView<MatrixType,UpLo>& TriangularViewImpl<MatrixType,UpLo,Dense>::_assignProduct(const ProductType& prod, const Scalar& alpha, bool beta)
{
eigen_assert(derived().nestedExpression().rows() == prod.rows() && derived().cols() == prod.cols());
- general_product_to_triangular_selector<MatrixType, ProductType, UpLo, internal::traits<ProductType>::InnerSize==1>::run(derived().nestedExpression().const_cast_derived(), prod, alpha);
+ general_product_to_triangular_selector<MatrixType, ProductType, UpLo, internal::traits<ProductType>::InnerSize==1>::run(derived().nestedExpression().const_cast_derived(), prod, alpha, beta);
return derived();
}
diff --git a/test/product_mmtr.cpp b/test/product_mmtr.cpp
index b66529acd..f6e4bb1ae 100644
--- a/test/product_mmtr.cpp
+++ b/test/product_mmtr.cpp
@@ -62,6 +62,19 @@ template<typename Scalar> void mmtr(int size)
CHECK_MMTR(matc, Upper, -= (s*sqc).template triangularView<Upper>()*sqc);
CHECK_MMTR(matc, Lower, = (s*sqr).template triangularView<Lower>()*sqc);
CHECK_MMTR(matc, Upper, += (s*sqc).template triangularView<Lower>()*sqc);
+
+ // check aliasing
+ ref2 = ref1 = matc;
+ ref1 = sqc.adjoint() * matc * sqc;
+ ref2.template triangularView<Upper>() = ref1.template triangularView<Upper>();
+ matc.template triangularView<Upper>() = sqc.adjoint() * matc * sqc;
+ VERIFY_IS_APPROX(matc, ref2);
+
+ ref2 = ref1 = matc;
+ ref1 = sqc * matc * sqc.adjoint();
+ ref2.template triangularView<Lower>() = ref1.template triangularView<Lower>();
+ matc.template triangularView<Lower>() = sqc * matc * sqc.adjoint();
+ VERIFY_IS_APPROX(matc, ref2);
}
void test_product_mmtr()