diff options
author | Gael Guennebaud <g.gael@free.fr> | 2009-07-31 13:18:19 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2009-07-31 13:18:19 +0200 |
commit | a156f5a8696a22a2738ab54bcc77a7565d5e9019 (patch) | |
tree | 8c4c2bf99c6ce1a90d5a87ae914016679423f07e | |
parent | ff20a2ba9423161efd2796832b896f9a4a1a3006 (diff) |
faster trsm kernel and fix a couple of issues
-rw-r--r-- | Eigen/src/Core/SolveTriangular.h | 4 | ||||
-rw-r--r-- | Eigen/src/Core/products/TriangularSolverMatrix.h | 54 |
2 files changed, 22 insertions, 36 deletions
diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h index f60ef1c03..9b67dd580 100644 --- a/Eigen/src/Core/SolveTriangular.h +++ b/Eigen/src/Core/SolveTriangular.h @@ -49,7 +49,7 @@ struct ei_triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,RowMajor { static const int PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH; ActualLhsType actualLhs = LhsProductTraits::extract(lhs); - + const int size = lhs.cols(); for(int pi=IsLowerTriangular ? 0 : size; IsLowerTriangular ? pi<size : pi>0; @@ -224,7 +224,7 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>& ei_assert(!(Mode & ZeroDiagBit)); ei_assert(Mode & (UpperTriangularBit|LowerTriangularBit)); - enum { copy = ei_traits<RhsDerived>::Flags & RowMajorBit }; + enum { copy = ei_traits<RhsDerived>::Flags & RowMajorBit && RhsDerived::IsVectorAtCompileTime }; typedef typename ei_meta_if<copy, typename ei_plain_matrix_type_column_major<RhsDerived>::type, RhsDerived&>::ret RhsCopy; RhsCopy rhsCopy(rhs); diff --git a/Eigen/src/Core/products/TriangularSolverMatrix.h b/Eigen/src/Core/products/TriangularSolverMatrix.h index 7842f8703..e49fac956 100644 --- a/Eigen/src/Core/products/TriangularSolverMatrix.h +++ b/Eigen/src/Core/products/TriangularSolverMatrix.h @@ -25,7 +25,7 @@ #ifndef EIGEN_TRIANGULAR_SOLVER_MATRIX_H #define EIGEN_TRIANGULAR_SOLVER_MATRIX_H -// if the rhs is row major, we have to evaluate it in a temporary colmajor matrix +// if the rhs is row major, let's transpose the product template <typename Scalar, int Side, int Mode, bool Conjugate, int TriStorageOrder> struct ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,RowMajor> { @@ -34,22 +34,16 @@ struct ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,Row const Scalar* tri, int triStride, Scalar* _other, int otherStride) { - ei_triangular_solve_matrix< Scalar, Side==OnTheLeft?OnTheRight:OnTheLeft, - (Mode&UnitDiagBit) | (Mode&UpperTriangular) ? LowerTriangular : UpperTriangular, - !Conjugate, TriStorageOrder, ColMajor> + (Mode&UnitDiagBit) | ((Mode&UpperTriangular) ? LowerTriangular : UpperTriangular), + NumTraits<Scalar>::IsComplex && Conjugate, + TriStorageOrder==RowMajor ? ColMajor : RowMajor, ColMajor> ::run(size, cols, tri, triStride, _other, otherStride); - -// Map<Matrix<Scalar,Dynamic,Dynamic> > other(_other, otherStride, cols); -// Matrix<Scalar,Dynamic,Dynamic> aux = other.block(0,0,size,cols); -// ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,ColMajor> -// ::run(size, cols, tri, triStride, aux.data(), aux.stride()); -// other.block(0,0,size,cols) = aux; } }; -/* Optimized triangular solver with multiple right hand side (_TRSM) +/* Optimized triangular solver with multiple right hand side and the triangular matrix on the left */ template <typename Scalar, int Mode, bool Conjugate, int TriStorageOrder> struct ei_triangular_solve_matrix<Scalar,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor> @@ -190,11 +184,8 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd Scalar* _other, int otherStride) { int rows = otherSize; -// ei_const_blas_data_mapper<Scalar, TriStorageOrder> rhs(_tri,triStride); -// ei_blas_data_mapper<Scalar, ColMajor> lhs(_other,otherStride); - - Map<Matrix<Scalar,Dynamic,Dynamic,TriStorageOrder> > rhs(_tri,size,size); - Map<Matrix<Scalar,Dynamic,Dynamic,ColMajor> > lhs(_other,rows,size); + ei_const_blas_data_mapper<Scalar, TriStorageOrder> rhs(_tri,triStride); + ei_blas_data_mapper<Scalar, ColMajor> lhs(_other,otherStride); typedef ei_product_blocking_traits<Scalar> Blocking; enum { @@ -203,8 +194,8 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd IsLowerTriangular = (Mode&LowerTriangular) == LowerTriangular }; - int kc = std::min<int>(/*Blocking::Max_kc/4*/32,size); // cache block size along the K direction - int mc = std::min<int>(/*Blocking::Max_mc*/32,size); // cache block size along the M direction + int kc = std::min<int>(Blocking::Max_kc/4,size); // cache block size along the K direction + int mc = std::min<int>(Blocking::Max_mc,size); // cache block size along the M direction Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc); Scalar* blockB = ei_aligned_stack_new(Scalar, kc*size*Blocking::PacketSize); @@ -214,7 +205,6 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder> pack_rhs; ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder,true> pack_rhs_panel; ei_gemm_pack_lhs<Scalar, Blocking::mr, ColMajor, false, true> pack_lhs_panel; - ei_gemm_pack_lhs<Scalar, Blocking::mr, ColMajor, false> pack_lhs; for(int k2=IsLowerTriangular ? size : 0; IsLowerTriangular ? k2>0 : k2<size; @@ -224,7 +214,7 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd int actual_k2 = IsLowerTriangular ? k2-actual_kc : k2 ; int startPanel = IsLowerTriangular ? 0 : k2+actual_kc; - int rs = IsLowerTriangular ? actual_k2 : size - actual_k2 - actual_kc; + int rs = IsLowerTriangular ? actual_k2 : size - actual_k2 - actual_kc; Scalar* geb = blockB+actual_kc*actual_kc*Blocking::PacketSize; if (rs>0) pack_rhs(geb, &rhs(actual_k2,startPanel), triStride, -1, actual_kc, rs); @@ -239,8 +229,6 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd int panelOffset = IsLowerTriangular ? j2+actualPanelWidth : 0; int panelLength = IsLowerTriangular ? actual_kc-j2-actualPanelWidth : j2; -// std::cerr << "$ " << k2 << " " << j2 << " " << actual_j2 << " " << panelOffset << " " << panelLength << "\n"; - if (panelLength>0) pack_rhs_panel(blockB+j2*actual_kc*Blocking::PacketSize, &rhs(actual_k2+panelOffset, actual_j2), triStride, -1, @@ -269,7 +257,6 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd int panelLength = IsLowerTriangular ? actual_kc - j2 - actualPanelWidth : j2; // GEBP - //if (lengthTarget>0) if(panelLength>0) { gebp_kernel(&lhs(i2,absolute_j2), otherStride, @@ -284,18 +271,17 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd { int j = IsLowerTriangular ? absolute_j2+actualPanelWidth-k-1 : absolute_j2+k; - Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(rhs(j,j)); - for (int i=0; i<actual_mc; ++i) + Scalar* r = &lhs(i2,j); + for (int k3=0; k3<k; ++k3) { - int absolute_i = i2+i; - Scalar b = 0; - for (int k3=0; k3<k; ++k3) - if(IsLowerTriangular) - b += lhs(absolute_i,j+1+k3) * conj(rhs(j+1+k3,j)); - else - b += lhs(absolute_i,absolute_j2+k3) * conj(rhs(absolute_j2+k3,j)); - lhs(absolute_i,j) = (lhs(absolute_i,j) - b)*a; + Scalar b = conj(rhs(IsLowerTriangular ? j+1+k3 : absolute_j2+k3,j)); + Scalar* a = &lhs(i2,IsLowerTriangular ? j+1+k3 : absolute_j2+k3); + for (int i=0; i<actual_mc; ++i) + r[i] -= a[i] * b; } + Scalar b = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(rhs(j,j)); + for (int i=0; i<actual_mc; ++i) + r[i] *= b; } // pack the just computed part of lhs to A @@ -304,7 +290,7 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd actual_kc, j2); } } - + if (rs>0) gebp_kernel(_other+i2+startPanel*otherStride, otherStride, blockA, geb, actual_mc, actual_kc, rs); |