aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/products
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2019-09-10 16:25:24 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2019-09-10 16:25:24 +0200
commitea0d5dc956c1268dd91ce636d8fd5e07225acb06 (patch)
tree624180571f205e23b72d4a8c87455ccfad30ebf5 /Eigen/src/Core/products
parent17226100c5e56d1c6064560390a4a6e16677bb45 (diff)
bug #1741: fix C.noalias() = A*C; with C.innerStride()!=1
Diffstat (limited to 'Eigen/src/Core/products')
-rw-r--r--Eigen/src/Core/products/GeneralMatrixMatrix.h31
-rw-r--r--Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h6
-rw-r--r--Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h8
3 files changed, 25 insertions, 20 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h
index 90c9c4647..508c05c97 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h
@@ -20,8 +20,9 @@ template<typename _LhsScalar, typename _RhsScalar> class level3_blocking;
template<
typename Index,
typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
- typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs>
-struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor>
+ typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
+ int ResInnerStride>
+struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor,ResInnerStride>
{
typedef gebp_traits<RhsScalar,LhsScalar> Traits;
@@ -30,7 +31,7 @@ struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLh
Index rows, Index cols, Index depth,
const LhsScalar* lhs, Index lhsStride,
const RhsScalar* rhs, Index rhsStride,
- ResScalar* res, Index resStride,
+ ResScalar* res, Index resIncr, Index resStride,
ResScalar alpha,
level3_blocking<RhsScalar,LhsScalar>& blocking,
GemmParallelInfo<Index>* info = 0)
@@ -39,8 +40,8 @@ struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLh
general_matrix_matrix_product<Index,
RhsScalar, RhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateRhs,
LhsScalar, LhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateLhs,
- ColMajor>
- ::run(cols,rows,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha,blocking,info);
+ ColMajor,ResInnerStride>
+ ::run(cols,rows,depth,rhs,rhsStride,lhs,lhsStride,res,resIncr,resStride,alpha,blocking,info);
}
};
@@ -49,8 +50,9 @@ struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLh
template<
typename Index,
typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
- typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs>
-struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor>
+ typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
+ int ResInnerStride>
+struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride>
{
typedef gebp_traits<LhsScalar,RhsScalar> Traits;
@@ -59,17 +61,17 @@ typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScala
static void run(Index rows, Index cols, Index depth,
const LhsScalar* _lhs, Index lhsStride,
const RhsScalar* _rhs, Index rhsStride,
- ResScalar* _res, Index resStride,
+ ResScalar* _res, Index resIncr, Index resStride,
ResScalar alpha,
level3_blocking<LhsScalar,RhsScalar>& blocking,
GemmParallelInfo<Index>* info = 0)
{
typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper;
typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper;
- typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor> ResMapper;
- LhsMapper lhs(_lhs,lhsStride);
- RhsMapper rhs(_rhs,rhsStride);
- ResMapper res(_res, resStride);
+ typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor,Unaligned,ResInnerStride> ResMapper;
+ LhsMapper lhs(_lhs, lhsStride);
+ RhsMapper rhs(_rhs, rhsStride);
+ ResMapper res(_res, resStride, resIncr);
Index kc = blocking.kc(); // cache block size along the K direction
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction
@@ -228,7 +230,7 @@ struct gemm_functor
Gemm::run(rows, cols, m_lhs.cols(),
&m_lhs.coeffRef(row,0), m_lhs.outerStride(),
&m_rhs.coeffRef(0,col), m_rhs.outerStride(),
- (Scalar*)&(m_dest.coeffRef(row,col)), m_dest.outerStride(),
+ (Scalar*)&(m_dest.coeffRef(row,col)), m_dest.innerStride(), m_dest.outerStride(),
m_actualAlpha, m_blocking, info);
}
@@ -498,7 +500,8 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
Index,
LhsScalar, (ActualLhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
RhsScalar, (ActualRhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
- (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
+ (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,
+ Dest::InnerStrideAtCompileTime>,
ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> GemmFunctor;
BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true);
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h b/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h
index b0f6b0d5b..71abf4013 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h
@@ -51,20 +51,22 @@ template< \
typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
-struct general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor> \
+struct general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1> \
{ \
typedef gebp_traits<EIGTYPE,EIGTYPE> Traits; \
\
static void run(Index rows, Index cols, Index depth, \
const EIGTYPE* _lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsStride, \
- EIGTYPE* res, Index resStride, \
+ EIGTYPE* res, Index resIncr, Index resStride, \
EIGTYPE alpha, \
level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/, \
GemmParallelInfo<Index>* /*info = 0*/) \
{ \
using std::conj; \
\
+ EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
+ eigen_assert(resIncr == 1); \
char transa, transb; \
BlasIndex m, n, k, lda, ldb, ldc; \
const EIGTYPE *a, *b; \
diff --git a/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h b/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h
index a25197ab0..a01ac0588 100644
--- a/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h
+++ b/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h
@@ -124,8 +124,8 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \
BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
- general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
- rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \
+ general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1>::run( \
+ rows, cols, depth, aa_tmp.data(), aStride, _rhs, 1, rhsStride, res, resStride, alpha, gemm_blocking, 0); \
\
/*std::cout << "TRMM_L: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
} \
@@ -241,8 +241,8 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \
BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
- general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
- rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \
+ general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1>::run( \
+ rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, 1, resStride, alpha, gemm_blocking, 0); \
\
/*std::cout << "TRMM_R: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
} \