diff options
author | Gael Guennebaud <g.gael@free.fr> | 2010-06-15 22:00:34 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2010-06-15 22:00:34 +0200 |
commit | 2e792d1f42e895175e9536e141456326da1176ed (patch) | |
tree | 3dcf012ac8dc3da469cdb9adc3bd165d8da31255 | |
parent | 134ca4acb3860c2521ef73508023b9c9d8cac4ec (diff) |
* make the triangular matrix * matrix product work with trapezoidal matrices
* extend the trmm unit test for unit diagonal
-rw-r--r-- | Eigen/src/Core/products/TriangularMatrixMatrix.h | 51 | ||||
-rw-r--r-- | test/product_trmm.cpp | 13 |
2 files changed, 38 insertions, 26 deletions
diff --git a/Eigen/src/Core/products/TriangularMatrixMatrix.h b/Eigen/src/Core/products/TriangularMatrixMatrix.h index 7f6177ea7..967deaffb 100644 --- a/Eigen/src/Core/products/TriangularMatrixMatrix.h +++ b/Eigen/src/Core/products/TriangularMatrixMatrix.h @@ -49,7 +49,7 @@ // } // }; -/* Optimized selfadjoint matrix * matrix (_SYMM) product built on top of +/* Optimized triangular matrix * matrix (_TRMM++) product built on top of * the general matrix matrix product. */ template <typename Scalar, typename Index, @@ -68,7 +68,7 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,LhsIsTriangular, RhsStorageOrder,ConjugateRhs,RowMajor> { static EIGEN_STRONG_INLINE void run( - Index size, Index otherSize, + Index rows, Index cols, Index depth, const Scalar* lhs, Index lhsStride, const Scalar* rhs, Index rhsStride, Scalar* res, Index resStride, @@ -82,7 +82,7 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,LhsIsTriangular, LhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateLhs, ColMajor> - ::run(size, otherSize, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha); + ::run(rows, cols, depth, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha); } }; @@ -96,14 +96,12 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,true, { static EIGEN_DONT_INLINE void run( - Index size, Index cols, + Index rows, Index cols, Index depth, const Scalar* _lhs, Index lhsStride, const Scalar* _rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha) { - Index rows = size; - ei_const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride); ei_const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride); @@ -116,8 +114,8 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,true, IsLower = (Mode&Lower) == Lower }; - Index kc = std::min<Index>(Blocking::Max_kc/4,size); // cache block size along the K direction - Index mc = std::min<Index>(Blocking::Max_mc,rows); // cache block size along the M 
direction + Index kc = std::min<Index>(Blocking::Max_kc/4,depth); // cache block size along the K direction + Index mc = std::min<Index>(Blocking::Max_mc,rows); // cache block size along the M direction Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc); std::size_t sizeB = kc*Blocking::PacketSize*Blocking::nr + kc*cols; @@ -133,20 +131,27 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,true, ei_gemm_pack_lhs<Scalar, Index, Blocking::mr,LhsStorageOrder> pack_lhs; ei_gemm_pack_rhs<Scalar, Index, Blocking::nr,RhsStorageOrder> pack_rhs; - for(Index k2=IsLower ? size : 0; - IsLower ? k2>0 : k2<size; + for(Index k2=IsLower ? depth : 0; + IsLower ? k2>0 : k2<depth; IsLower ? k2-=kc : k2+=kc) { - const Index actual_kc = std::min(IsLower ? k2 : size-k2, kc); + Index actual_kc = std::min(IsLower ? k2 : depth-k2, kc); Index actual_k2 = IsLower ? k2-actual_kc : k2; + if((!IsLower)&&(k2<rows)&&(k2+actual_kc>rows)) + { + actual_kc = rows-k2; + k2 = k2+actual_kc-kc; + } + pack_rhs(blockB, &rhs(actual_k2,0), rhsStride, alpha, actual_kc, cols); // the selected lhs's panel has to be split in three different parts: // 1 - the part which is above the diagonal block => skip it // 2 - the diagonal block => special kernel // 3 - the panel below the diagonal block => GEPP - // the block diagonal + // the block diagonal, if any + if(IsLower || actual_k2<rows) { // for each small vertical panels of lhs for (Index k1=0; k1<actual_kc; k1+=SmallPanelWidth) @@ -186,7 +191,7 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,true, // the part below the diagonal => GEPP { Index start = IsLower ? k2 : 0; - Index end = IsLower ? size : actual_k2; + Index end = IsLower ? 
rows : actual_k2; for(Index i2=start; i2<end; i2+=mc) { const Index actual_mc = std::min(i2+mc,end)-i2; @@ -214,14 +219,12 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,false, { static EIGEN_DONT_INLINE void run( - Index size, Index rows, + Index rows, Index cols, Index depth, const Scalar* _lhs, Index lhsStride, const Scalar* _rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha) { - Index cols = size; - ei_const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride); ei_const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride); @@ -234,8 +237,8 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,false, IsLower = (Mode&Lower) == Lower }; - Index kc = std::min<Index>(Blocking::Max_kc/4,size); // cache block size along the K direction - Index mc = std::min<Index>(Blocking::Max_mc,rows); // cache block size along the M direction + Index kc = std::min<Index>(Blocking::Max_kc/4,depth); // cache block size along the K direction + Index mc = std::min<Index>(Blocking::Max_mc,rows); // cache block size along the M direction Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc); std::size_t sizeB = kc*Blocking::PacketSize*Blocking::nr + kc*cols; @@ -251,13 +254,13 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,false, ei_gemm_pack_rhs<Scalar, Index, Blocking::nr,RhsStorageOrder> pack_rhs; ei_gemm_pack_rhs<Scalar, Index, Blocking::nr,RhsStorageOrder,true> pack_rhs_panel; - for(Index k2=IsLower ? 0 : size; - IsLower ? k2<size : k2>0; + for(Index k2=IsLower ? 0 : depth; + IsLower ? k2<depth : k2>0; IsLower ? k2+=kc : k2-=kc) { - const Index actual_kc = std::min(IsLower ? size-k2 : k2, kc); + const Index actual_kc = std::min(IsLower ? depth-k2 : k2, kc); Index actual_k2 = IsLower ? k2 : k2-actual_kc; - Index rs = IsLower ? actual_k2 : size - k2; + Index rs = IsLower ? actual_k2 : depth - k2; Scalar* geb = blockB+actual_kc*actual_kc; pack_rhs(geb, &rhs(actual_k2,IsLower ? 
0 : k2), rhsStride, alpha, actual_kc, rs); @@ -355,11 +358,11 @@ struct TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false> (ei_traits<_ActualRhsType>::Flags&RowMajorBit) ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate, (ei_traits<Dest >::Flags&RowMajorBit) ? RowMajor : ColMajor> ::run( - lhs.rows(), LhsIsTriangular ? rhs.cols() : lhs.rows(), // sizes + lhs.rows(), rhs.cols(), lhs.cols(),// LhsIsTriangular ? rhs.cols() : lhs.rows(), // sizes &lhs.coeff(0,0), lhs.outerStride(), // lhs info &rhs.coeff(0,0), rhs.outerStride(), // rhs info &dst.coeffRef(0,0), dst.outerStride(), // result info - actualAlpha // alpha + actualAlpha // alpha ); } }; diff --git a/test/product_trmm.cpp b/test/product_trmm.cpp index 69e97f7aa..e8580cbd2 100644 --- a/test/product_trmm.cpp +++ b/test/product_trmm.cpp @@ -28,8 +28,11 @@ template<typename Scalar> void trmm(int size,int othersize) { typedef typename NumTraits<Scalar>::Real RealScalar; - Matrix<Scalar,Dynamic,Dynamic,ColMajor> tri(size,size), upTri(size,size), loTri(size,size); - Matrix<Scalar,Dynamic,Dynamic,ColMajor> ge1(size,othersize), ge2(10,size), ge3; + typedef Matrix<Scalar,Dynamic,Dynamic,ColMajor> MatrixType; + + MatrixType tri(size,size), upTri(size,size), loTri(size,size), + unitUpTri(size,size), unitLoTri(size,size); + MatrixType ge1(size,othersize), ge2(10,size), ge3; Matrix<Scalar,Dynamic,Dynamic,RowMajor> rge3; Scalar s1 = ei_random<Scalar>(), @@ -38,6 +41,8 @@ template<typename Scalar> void trmm(int size,int othersize) tri.setRandom(); loTri = tri.template triangularView<Lower>(); upTri = tri.template triangularView<Upper>(); + unitLoTri = tri.template triangularView<UnitLower>(); + unitUpTri = tri.template triangularView<UnitUpper>(); ge1.setRandom(); ge2.setRandom(); @@ -57,6 +62,10 @@ template<typename Scalar> void trmm(int size,int othersize) VERIFY_IS_APPROX(rge3 = tri.adjoint().template triangularView<Upper>() * ge2.adjoint(), loTri.adjoint() * ge2.adjoint()); VERIFY_IS_APPROX( ge3 = 
tri.adjoint().template triangularView<Lower>() * ge2.adjoint(), upTri.adjoint() * ge2.adjoint()); VERIFY_IS_APPROX(rge3 = tri.adjoint().template triangularView<Lower>() * ge2.adjoint(), upTri.adjoint() * ge2.adjoint()); + + VERIFY_IS_APPROX( ge3 = tri.template triangularView<UnitLower>() * ge1, unitLoTri * ge1); + VERIFY_IS_APPROX(rge3 = tri.template triangularView<UnitLower>() * ge1, unitLoTri * ge1); + VERIFY_IS_APPROX( ge3 = (s1*tri).adjoint().template triangularView<UnitUpper>() * ge2.adjoint(), ei_conj(s1) * unitLoTri.adjoint() * ge2.adjoint()); } void test_product_trmm() |