diff options
author | Gael Guennebaud <g.gael@free.fr> | 2009-07-26 13:01:37 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2009-07-26 13:01:37 +0200 |
commit | f3fde74695eff236fe24b05ffb053d3890346420 (patch) | |
tree | e67b2323c788de9cb55686499c8d08c8ca885a1b /Eigen/src/Core | |
parent | 282e18da4915943b0dc2ed0140cfd037aaa02a70 (diff) |
finalize trsm: works in all situations, and it is now used by solve() and solveInPlace()
Diffstat (limited to 'Eigen/src/Core')
-rw-r--r-- | Eigen/src/Core/SolveTriangular.h | 31 | ||||
-rw-r--r-- | Eigen/src/Core/products/TriangularSolverMatrix.h | 111 |
2 files changed, 89 insertions, 53 deletions
diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h index cb162ca91..d0656eacb 100644 --- a/Eigen/src/Core/SolveTriangular.h +++ b/Eigen/src/Core/SolveTriangular.h @@ -29,13 +29,14 @@ template<typename Lhs, typename Rhs, int Mode, // can be Upper/Lower | UnitDiag int Unrolling = Rhs::IsVectorAtCompileTime && Rhs::SizeAtCompileTime <= 8 // FIXME ? CompleteUnrolling : NoUnrolling, - int StorageOrder = int(Lhs::Flags) & RowMajorBit + int StorageOrder = int(Lhs::Flags) & RowMajorBit, + int RhsCols = Rhs::ColsAtCompileTime > struct ei_triangular_solver_selector; -// forward and backward substitution, row-major +// forward and backward substitution, row-major, rhs is a vector template<typename Lhs, typename Rhs, int Mode> -struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,RowMajor> +struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,RowMajor,1> { typedef typename Rhs::Scalar Scalar; typedef ei_blas_traits<Lhs> LhsProductTraits; @@ -89,9 +90,9 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,RowMajor> } }; -// forward and backward substitution, column-major +// forward and backward substitution, column-major, rhs is a vector template<typename Lhs, typename Rhs, int Mode> -struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,ColMajor> +struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,ColMajor,1> { typedef typename Rhs::Scalar Scalar; typedef typename ei_packet_traits<Scalar>::type Packet; @@ -150,6 +151,24 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,ColMajor> } }; +template <typename Scalar, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder, int Mode> +struct ei_triangular_solve_matrix; + +// the rhs is a matrix +template<typename Lhs, typename Rhs, int Mode, int StorageOrder, int RhsCols> +struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,StorageOrder,RhsCols> +{ + typedef typename Rhs::Scalar Scalar; + typedef ei_blas_traits<Lhs> LhsProductTraits; + typedef typename LhsProductTraits::ActualXprType ActualLhsType; + static void run(const Lhs& lhs, Rhs& rhs) + { + const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs); + ei_triangular_solve_matrix<Scalar,StorageOrder,LhsProductTraits::NeedToConjugate,Rhs::Flags&RowMajorBit,Mode> + ::run(lhs.rows(), rhs.cols(), &actualLhs.coeff(0,0), actualLhs.stride(), &rhs.coeffRef(0,0), rhs.stride()); + } +}; + /*************************************************************************** * meta-unrolling implementation ***************************************************************************/ @@ -184,7 +203,7 @@ struct ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,true> { }; template<typename Lhs, typename Rhs, int Mode, int StorageOrder> -struct ei_triangular_solver_selector<Lhs,Rhs,Mode,CompleteUnrolling,StorageOrder> { +struct ei_triangular_solver_selector<Lhs,Rhs,Mode,CompleteUnrolling,StorageOrder,1> { static void run(const Lhs& lhs, Rhs& rhs) { ei_triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); } }; diff --git a/Eigen/src/Core/products/TriangularSolverMatrix.h b/Eigen/src/Core/products/TriangularSolverMatrix.h index eeb445f0b..550076f68 100644 --- a/Eigen/src/Core/products/TriangularSolverMatrix.h +++ b/Eigen/src/Core/products/TriangularSolverMatrix.h @@ -26,63 +26,37 @@ #define EIGEN_TRIANGULAR_SOLVER_MATRIX_H template<typename Scalar, int nr> -struct ei_gemm_pack_rhs_panel +struct ei_gemm_pack_rhs_panel; + +// if the rhs is row major, we have to evaluate it in a temporary colmajor matrix +template <typename Scalar, int LhsStorageOrder, bool ConjugateLhs, int Mode> +struct ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,RowMajor,Mode> { - enum { PacketSize = ei_packet_traits<Scalar>::size }; - void operator()(Scalar* blockB, const Scalar* rhs, int rhsStride, Scalar alpha, int depth, int cols, int stride, int offset) + static EIGEN_DONT_INLINE void run( + int size, int cols, + const Scalar* lhs, int lhsStride, + Scalar* _rhs, int rhsStride) { - int packet_cols = (cols/nr) * nr; - int count = 0; - for(int j2=0; j2<packet_cols; j2+=nr) - { - // skip what we have before - count += PacketSize * nr * offset; - const Scalar* b0 = &rhs[(j2+0)*rhsStride]; - const Scalar* b1 = &rhs[(j2+1)*rhsStride]; - const Scalar* b2 = &rhs[(j2+2)*rhsStride]; - const Scalar* b3 = &rhs[(j2+3)*rhsStride]; - for(int k=0; k<depth; k++) - { - ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*b0[k])); - ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*b1[k])); - if(nr==4) ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*b2[k])); - if(nr==4) ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*b3[k])); - count += nr*PacketSize; - } - // skip what we have after - count += PacketSize * nr * (stride-offset-depth); - } - // copy the remaining columns one at a time (nr==1) - for(int j2=packet_cols; j2<cols; ++j2) - { - count += PacketSize * offset; - const Scalar* b0 = &rhs[(j2+0)*rhsStride]; - for(int k=0; k<depth; k++) - { - ei_pstore(&blockB[count], ei_pset1(alpha*b0[k])); - count += PacketSize; - } - count += PacketSize * (stride-offset-depth); - } + Map<Matrix<Scalar,Dynamic,Dynamic> > rhs(_rhs, rhsStride, cols); + Matrix<Scalar,Dynamic,Dynamic> aux = rhs.block(0,0,size,cols); + ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,Mode> + ::run(size, cols, lhs, lhsStride, aux.data(), aux.stride()); + rhs.block(0,0,size,cols) = aux; } }; /* Optimized triangular solver with multiple right hand side (_TRSM) */ -template <typename Scalar, - int LhsStorageOrder, - int RhsStorageOrder, - int Mode> -struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder> +template <typename Scalar, int LhsStorageOrder, bool ConjugateLhs, int Mode> +struct ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,Mode> { - static EIGEN_DONT_INLINE void run( int size, int cols, const Scalar* _lhs, int lhsStride, Scalar* _rhs, int rhsStride) { ei_const_blas_data_mapper<Scalar, LhsStorageOrder> lhs(_lhs,lhsStride); - ei_blas_data_mapper <Scalar, RhsStorageOrder> rhs(_rhs,rhsStride); + ei_blas_data_mapper <Scalar, ColMajor> rhs(_rhs,rhsStride); typedef ei_product_blocking_traits<Scalar> Blocking; enum { @@ -96,7 +70,8 @@ struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder> Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc); Scalar* blockB = ei_aligned_stack_new(Scalar, kc*cols*Blocking::PacketSize); - ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<false,false> > gebp_kernel; + ei_conj_if<ConjugateLhs> conj; + ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<ConjugateLhs,false> > gebp_kernel; ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder> pack_lhs; for(int k2=IsLowerTriangular ? 0 : size; @@ -131,7 +106,7 @@ struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder> int s = IsLowerTriangular ? k2+k1 : i+1; int rs = actualPanelWidth - k - 1; // remaining size - Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/lhs(i,i); + Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(lhs(i,i)); for (int j=0; j<cols; ++j) { if (LhsStorageOrder==RowMajor) @@ -140,7 +115,7 @@ struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder> const Scalar* l = &lhs(i,s); Scalar* r = &rhs(s,j); for (int i3=0; i3<k; ++i3) - b += l[i3] * r[i3]; + b += conj(l[i3]) * r[i3]; rhs(i,j) = (rhs(i,j) - b)*a; } @@ -151,7 +126,7 @@ struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder> Scalar* r = &rhs(s,j); const Scalar* l = &lhs(s,i); for (int i3=0;i3<rs;++i3) - r[i3] -= b * l[i3]; + r[i3] -= b * conj(l[i3]); } } } @@ -199,4 +174,46 @@ struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder> } }; +template<typename Scalar, int nr> +struct ei_gemm_pack_rhs_panel +{ + enum { PacketSize = ei_packet_traits<Scalar>::size }; + void operator()(Scalar* blockB, const Scalar* rhs, int rhsStride, Scalar alpha, int depth, int cols, int stride, int offset) + { + int packet_cols = (cols/nr) * nr; + int count = 0; + for(int j2=0; j2<packet_cols; j2+=nr) + { + // skip what we have before + count += PacketSize * nr * offset; + const Scalar* b0 = &rhs[(j2+0)*rhsStride]; + const Scalar* b1 = &rhs[(j2+1)*rhsStride]; + const Scalar* b2 = &rhs[(j2+2)*rhsStride]; + const Scalar* b3 = &rhs[(j2+3)*rhsStride]; + for(int k=0; k<depth; k++) + { + ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*b0[k])); + ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*b1[k])); + if(nr==4) ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*b2[k])); + if(nr==4) ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*b3[k])); + count += nr*PacketSize; + } + // skip what we have after + count += PacketSize * nr * (stride-offset-depth); + } + // copy the remaining columns one at a time (nr==1) + for(int j2=packet_cols; j2<cols; ++j2) + { + count += PacketSize * offset; + const Scalar* b0 = &rhs[(j2+0)*rhsStride]; + for(int k=0; k<depth; k++) + { + ei_pstore(&blockB[count], ei_pset1(alpha*b0[k])); + count += PacketSize; + } + count += PacketSize * (stride-offset-depth); + } + } +}; + #endif // EIGEN_TRIANGULAR_SOLVER_MATRIX_H |