diff options
author | Gael Guennebaud <g.gael@free.fr> | 2012-06-12 11:33:50 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2012-06-12 11:33:50 +0200 |
commit | 924c7a9300101de181038bab04eda64d84a671ee (patch) | |
tree | e6eb672d6fca715e3c9dc5dfe3af1b11607afa54 /Eigen/src/Core/products/TriangularSolverMatrix.h | |
parent | bc580bbffb6f0fd0dd2fa335612e76d9ed545476 (diff) |
avoid dynamic allocation for fixed size triangular solving
Diffstat (limited to 'Eigen/src/Core/products/TriangularSolverMatrix.h')
-rw-r--r-- | Eigen/src/Core/products/TriangularSolverMatrix.h | 53 |
1 files changed, 26 insertions, 27 deletions
diff --git a/Eigen/src/Core/products/TriangularSolverMatrix.h b/Eigen/src/Core/products/TriangularSolverMatrix.h index 4bba12cfe..0dd94638b 100644 --- a/Eigen/src/Core/products/TriangularSolverMatrix.h +++ b/Eigen/src/Core/products/TriangularSolverMatrix.h @@ -36,14 +36,15 @@ struct triangular_solve_matrix<Scalar,Index,Side,Mode,Conjugate,TriStorageOrder, static EIGEN_DONT_INLINE void run( Index size, Index cols, const Scalar* tri, Index triStride, - Scalar* _other, Index otherStride) + Scalar* _other, Index otherStride, + level3_blocking<Scalar,Scalar>& blocking) { triangular_solve_matrix< Scalar, Index, Side==OnTheLeft?OnTheRight:OnTheLeft, (Mode&UnitDiag) | ((Mode&Upper) ? Lower : Upper), NumTraits<Scalar>::IsComplex && Conjugate, TriStorageOrder==RowMajor ? ColMajor : RowMajor, ColMajor> - ::run(size, cols, tri, triStride, _other, otherStride); + ::run(size, cols, tri, triStride, _other, otherStride, blocking); } }; @@ -55,7 +56,8 @@ struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageO static EIGEN_DONT_INLINE void run( Index size, Index otherSize, const Scalar* _tri, Index triStride, - Scalar* _other, Index otherStride) + Scalar* _other, Index otherStride, + level3_blocking<Scalar,Scalar>& blocking) { Index cols = otherSize; const_blas_data_mapper<Scalar, Index, TriStorageOrder> tri(_tri,triStride); @@ -67,17 +69,16 @@ struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageO IsLower = (Mode&Lower) == Lower }; - Index kc = size; // cache block size along the K direction - Index mc = size; // cache block size along the M direction - Index nc = cols; // cache block size along the N direction - computeProductBlockingSizes<Scalar,Scalar,4>(kc, mc, nc); + Index kc = blocking.kc(); // cache block size along the K direction + Index mc = (std::min)(size,blocking.mc()); // cache block size along the M direction + std::size_t sizeA = kc*mc; + std::size_t sizeB = kc*cols; std::size_t sizeW = kc*Traits::WorkSpaceFactor; - std::size_t sizeB = sizeW + kc*cols; - ei_declare_aligned_stack_constructed_variable(Scalar, blockA, kc*mc, 0); - ei_declare_aligned_stack_constructed_variable(Scalar, allocatedBlockB, sizeB, 0); - Scalar* blockB = allocatedBlockB + sizeW; - Scalar* blockW = allocatedBlockB; + + ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA()); + ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB()); + ei_declare_aligned_stack_constructed_variable(Scalar, blockW, sizeW, blocking.blockW()); conj_if<Conjugate> conj; gebp_kernel<Scalar, Scalar, Index, Traits::mr, Traits::nr, Conjugate, false> gebp_kernel; @@ -181,7 +182,7 @@ struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageO { pack_lhs(blockA, &tri(i2, IsLower ? k2 : k2-kc), triStride, actual_kc, actual_mc); - gebp_kernel(_other+i2, otherStride, blockA, blockB, actual_mc, actual_kc, cols, Scalar(-1)); + gebp_kernel(_other+i2, otherStride, blockA, blockB, actual_mc, actual_kc, cols, Scalar(-1), -1, -1, 0, 0, blockW); } } } @@ -197,7 +198,8 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage static EIGEN_DONT_INLINE void run( Index size, Index otherSize, const Scalar* _tri, Index triStride, - Scalar* _other, Index otherStride) + Scalar* _other, Index otherStride, + level3_blocking<Scalar,Scalar>& blocking) { Index rows = otherSize; const_blas_data_mapper<Scalar, Index, TriStorageOrder> rhs(_tri,triStride); @@ -210,19 +212,16 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage IsLower = (Mode&Lower) == Lower }; -// Index kc = std::min<Index>(Traits::Max_kc/4,size); // cache block size along the K direction -// Index mc = std::min<Index>(Traits::Max_mc,size); // cache block size along the M direction - // check that !!!! - Index kc = size; // cache block size along the K direction - Index mc = size; // cache block size along the M direction - Index nc = rows; // cache block size along the N direction - computeProductBlockingSizes<Scalar,Scalar,4>(kc, mc, nc); + Index kc = blocking.kc(); // cache block size along the K direction + Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction + std::size_t sizeA = kc*mc; + std::size_t sizeB = kc*size; std::size_t sizeW = kc*Traits::WorkSpaceFactor; - std::size_t sizeB = sizeW + kc*size; - ei_declare_aligned_stack_constructed_variable(Scalar, blockA, kc*mc, 0); - ei_declare_aligned_stack_constructed_variable(Scalar, allocatedBlockB, sizeB, 0); - Scalar* blockB = allocatedBlockB + sizeW; + + ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA()); + ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB()); + ei_declare_aligned_stack_constructed_variable(Scalar, blockW, sizeW, blocking.blockW()); conj_if<Conjugate> conj; gebp_kernel<Scalar,Scalar, Index, Traits::mr, Traits::nr, false, Conjugate> gebp_kernel; @@ -289,7 +288,7 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage Scalar(-1), actual_kc, actual_kc, // strides panelOffset, panelOffset, // offsets - allocatedBlockB); // workspace + blockW); // workspace } // unblocked triangular solve @@ -320,7 +319,7 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage if (rs>0) gebp_kernel(_other+i2+startPanel*otherStride, otherStride, blockA, geb, actual_mc, actual_kc, rs, Scalar(-1), - -1, -1, 0, 0, allocatedBlockB); + -1, -1, 0, 0, blockW); } } } |