From f3fde74695eff236fe24b05ffb053d3890346420 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Sun, 26 Jul 2009 13:01:37 +0200 Subject: finalize trsm: works in all situations, and it is now used by solve() and solveInPlace() --- Eigen/src/Core/products/TriangularSolverMatrix.h | 111 +++++++++++++---------- 1 file changed, 64 insertions(+), 47 deletions(-) (limited to 'Eigen/src/Core/products/TriangularSolverMatrix.h') diff --git a/Eigen/src/Core/products/TriangularSolverMatrix.h b/Eigen/src/Core/products/TriangularSolverMatrix.h index eeb445f0b..550076f68 100644 --- a/Eigen/src/Core/products/TriangularSolverMatrix.h +++ b/Eigen/src/Core/products/TriangularSolverMatrix.h @@ -26,63 +26,37 @@ #define EIGEN_TRIANGULAR_SOLVER_MATRIX_H template -struct ei_gemm_pack_rhs_panel +struct ei_gemm_pack_rhs_panel; + +// if the rhs is row major, we have to evaluate it in a temporary colmajor matrix +template +struct ei_triangular_solve_matrix { - enum { PacketSize = ei_packet_traits::size }; - void operator()(Scalar* blockB, const Scalar* rhs, int rhsStride, Scalar alpha, int depth, int cols, int stride, int offset) + static EIGEN_DONT_INLINE void run( + int size, int cols, + const Scalar* lhs, int lhsStride, + Scalar* _rhs, int rhsStride) { - int packet_cols = (cols/nr) * nr; - int count = 0; - for(int j2=0; j2 > rhs(_rhs, rhsStride, cols); + Matrix aux = rhs.block(0,0,size,cols); + ei_triangular_solve_matrix + ::run(size, cols, lhs, lhsStride, aux.data(), aux.stride()); + rhs.block(0,0,size,cols) = aux; } }; /* Optimized triangular solver with multiple right hand side (_TRSM) */ -template -struct ei_triangular_solve_matrix// +template +struct ei_triangular_solve_matrix { - static EIGEN_DONT_INLINE void run( int size, int cols, const Scalar* _lhs, int lhsStride, Scalar* _rhs, int rhsStride) { ei_const_blas_data_mapper lhs(_lhs,lhsStride); - ei_blas_data_mapper rhs(_rhs,rhsStride); + ei_blas_data_mapper rhs(_rhs,rhsStride); typedef ei_product_blocking_traits Blocking; enum { @@ -96,7 +70,8 @@ struct ei_triangular_solve_matrix// Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc); Scalar* blockB = ei_aligned_stack_new(Scalar, kc*cols*Blocking::PacketSize); - ei_gebp_kernel > gebp_kernel; + ei_conj_if conj; + ei_gebp_kernel > gebp_kernel; ei_gemm_pack_lhs pack_lhs; for(int k2=IsLowerTriangular ? 0 : size; @@ -131,7 +106,7 @@ struct ei_triangular_solve_matrix// int s = IsLowerTriangular ? k2+k1 : i+1; int rs = actualPanelWidth - k - 1; // remaining size - Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/lhs(i,i); + Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(lhs(i,i)); for (int j=0; j const Scalar* l = &lhs(i,s); Scalar* r = &rhs(s,j); for (int i3=0; i3 Scalar* r = &rhs(s,j); const Scalar* l = &lhs(s,i); for (int i3=0;i3 } }; +template +struct ei_gemm_pack_rhs_panel +{ + enum { PacketSize = ei_packet_traits::size }; + void operator()(Scalar* blockB, const Scalar* rhs, int rhsStride, Scalar alpha, int depth, int cols, int stride, int offset) + { + int packet_cols = (cols/nr) * nr; + int count = 0; + for(int j2=0; j2