aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2017-06-09 14:38:04 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2017-06-09 14:38:04 +0200
commitba5cab576a7615af12389ff159c6ed57b5312d5e (patch)
treee4200fcea398221324069d837cdd68a2863dbd96
parent90168c003d80236e935589d56f9fbd8e9c7aaa75 (diff)
bug #1405: enable StrictlyLower/StrictlyUpper triangularView as the destination of matrix*matrix products.
-rw-r--r--Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h13
-rw-r--r--Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h2
-rw-r--r--test/product_mmtr.cpp11
3 files changed, 19 insertions, 7 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h
index ad38bcf51..e436c50a4 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h
@@ -269,10 +269,13 @@ struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,false>
enum {
IsRowMajor = (internal::traits<MatrixType>::Flags&RowMajorBit) ? 1 : 0,
LhsIsRowMajor = _ActualLhs::Flags&RowMajorBit ? 1 : 0,
- RhsIsRowMajor = _ActualRhs::Flags&RowMajorBit ? 1 : 0
+ RhsIsRowMajor = _ActualRhs::Flags&RowMajorBit ? 1 : 0,
+ SkipDiag = (UpLo&(UnitDiag|ZeroDiag))!=0
};
Index size = mat.cols();
+ if(SkipDiag)
+ size--;
Index depth = actualLhs.cols();
typedef internal::gemm_blocking_space<IsRowMajor ? RowMajor : ColMajor,typename Lhs::Scalar,typename Rhs::Scalar,
@@ -283,10 +286,11 @@ struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,false>
internal::general_matrix_matrix_triangular_product<Index,
typename Lhs::Scalar, LhsIsRowMajor ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate,
typename Rhs::Scalar, RhsIsRowMajor ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
- IsRowMajor ? RowMajor : ColMajor, UpLo>
+ IsRowMajor ? RowMajor : ColMajor, UpLo&(Lower|Upper)>
::run(size, depth,
- &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &actualRhs.coeffRef(0,0), actualRhs.outerStride(),
- mat.data(), mat.outerStride(), actualAlpha, blocking);
+ &actualLhs.coeffRef(SkipDiag&&(UpLo&Lower)==Lower ? 1 : 0,0), actualLhs.outerStride(),
+ &actualRhs.coeffRef(0,SkipDiag&&(UpLo&Upper)==Upper ? 1 : 0), actualRhs.outerStride(),
+ mat.data() + (SkipDiag ? (bool(IsRowMajor) != ((UpLo&Lower)==Lower) ? 1 : mat.outerStride() ) : 0), mat.outerStride(), actualAlpha, blocking);
}
};
@@ -294,6 +298,7 @@ template<typename MatrixType, unsigned int UpLo>
template<typename ProductType>
EIGEN_DEVICE_FUNC TriangularView<MatrixType,UpLo>& TriangularViewImpl<MatrixType,UpLo,Dense>::_assignProduct(const ProductType& prod, const Scalar& alpha, bool beta)
{
+ EIGEN_STATIC_ASSERT((UpLo&UnitDiag)==0, WRITING_TO_TRIANGULAR_PART_WITH_UNIT_DIAGONAL_IS_NOT_SUPPORTED);
eigen_assert(derived().nestedExpression().rows() == prod.rows() && derived().cols() == prod.cols());
general_product_to_triangular_selector<MatrixType, ProductType, UpLo, internal::traits<ProductType>::InnerSize==1>::run(derived().nestedExpression().const_cast_derived(), prod, alpha, beta);
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h
index 5b7c15cca..41e18ff07 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h
@@ -52,7 +52,7 @@ struct general_matrix_matrix_triangular_product<Index,Scalar,LhsStorageOrder,Con
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const Scalar* lhs, Index lhsStride, \
const Scalar* rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha, level3_blocking<Scalar, Scalar>& blocking) \
{ \
- if (lhs==rhs) { \
+ if ( lhs==rhs && ((UpLo&(Lower|Upper)==UpLo)) ) { \
general_matrix_matrix_rankupdate<Index,Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,UpLo> \
::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha,blocking); \
} else { \
diff --git a/test/product_mmtr.cpp b/test/product_mmtr.cpp
index f6e4bb1ae..d3e24b012 100644
--- a/test/product_mmtr.cpp
+++ b/test/product_mmtr.cpp
@@ -1,7 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
-// Copyright (C) 2010 Gael Guennebaud <gael.guennebaud@inria.fr>
+// Copyright (C) 2010-2017 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@@ -10,12 +10,19 @@
#include "main.h"
#define CHECK_MMTR(DEST, TRI, OP) { \
+ ref3 = DEST; \
ref2 = ref1 = DEST; \
DEST.template triangularView<TRI>() OP; \
ref1 OP; \
ref2.template triangularView<TRI>() \
= ref1.template triangularView<TRI>(); \
VERIFY_IS_APPROX(DEST,ref2); \
+ \
+ DEST = ref3; \
+ ref3 = ref2; \
+ ref3.diagonal() = DEST.diagonal(); \
+ DEST.template triangularView<TRI|ZeroDiag>() OP; \
+ VERIFY_IS_APPROX(DEST,ref3); \
}
template<typename Scalar> void mmtr(int size)
@@ -27,7 +34,7 @@ template<typename Scalar> void mmtr(int size)
MatrixColMaj matc = MatrixColMaj::Zero(size, size);
MatrixRowMaj matr = MatrixRowMaj::Zero(size, size);
- MatrixColMaj ref1(size, size), ref2(size, size);
+ MatrixColMaj ref1(size, size), ref2(size, size), ref3(size,size);
MatrixColMaj soc(size,othersize); soc.setRandom();
MatrixColMaj osc(othersize,size); osc.setRandom();