diff options
author | Gael Guennebaud <g.gael@free.fr> | 2010-11-05 12:43:14 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2010-11-05 12:43:14 +0100 |
commit | 0e6c1170abab3aac8eb79b5662fdb9edae77e3cf (patch) | |
tree | b396fd25eb27e55ca033e55a9d9ed6a42aeff05a /Eigen/src/Core/SolveTriangular.h | |
parent | fe1353080ea5760daea332a8904edd78c0a9fb36 (diff) |
trsv: add support for inner-stride!=1, reduce code instanciation, move implementation to a new products/XX.h file
Diffstat (limited to 'Eigen/src/Core/SolveTriangular.h')
-rw-r--r-- | Eigen/src/Core/SolveTriangular.h | 124 |
1 files changed, 30 insertions, 94 deletions
diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h index abbf57553..b950d2c31 100644 --- a/Eigen/src/Core/SolveTriangular.h +++ b/Eigen/src/Core/SolveTriangular.h @@ -27,6 +27,15 @@ namespace internal { +// Forward declarations: +// The following two routines are implemented in the products/TriangularSolver*.h files +template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate, int StorageOrder> +struct triangular_solve_vector; + +template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder> +struct triangular_solve_matrix; + +// small helper struct extracting some traits on the underlying solver operation template<typename Lhs, typename Rhs, int Side> class trsolve_traits { @@ -51,111 +60,40 @@ template<typename Lhs, typename Rhs, > struct triangular_solver_selector; -// forward and backward substitution, row-major, rhs is a vector -template<typename Lhs, typename Rhs, int Mode> -struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,RowMajor,1> +template<typename Lhs, typename Rhs, int Mode, int StorageOrder> +struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,StorageOrder,1> { typedef typename Lhs::Scalar LhsScalar; typedef typename Rhs::Scalar RhsScalar; typedef blas_traits<Lhs> LhsProductTraits; typedef typename LhsProductTraits::ExtractType ActualLhsType; - typedef typename Lhs::Index Index; - enum { - IsLower = ((Mode&Lower)==Lower) - }; - static void run(const Lhs& lhs, Rhs& other) + typedef Map<Matrix<RhsScalar,Dynamic,1>, Aligned> MappedRhs; + static void run(const Lhs& lhs, Rhs& rhs) { - static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH; ActualLhsType actualLhs = LhsProductTraits::extract(lhs); - const Index size = lhs.cols(); - for(Index pi=IsLower ? 0 : size; - IsLower ? pi<size : pi>0; - IsLower ? pi+=PanelWidth : pi-=PanelWidth) - { - Index actualPanelWidth = std::min(IsLower ? size - pi : pi, PanelWidth); - - Index r = IsLower ? pi : size - pi; // remaining size - if (r > 0) - { - // let's directly call the low level product function because: - // 1 - it is faster to compile - // 2 - it is slighlty faster at runtime - Index startRow = IsLower ? pi : pi-actualPanelWidth; - Index startCol = IsLower ? 0 : pi; - - general_matrix_vector_product<Index,LhsScalar,RowMajor,LhsProductTraits::NeedToConjugate,RhsScalar,false>::run( - actualPanelWidth, r, - &(actualLhs.const_cast_derived().coeffRef(startRow,startCol)), actualLhs.outerStride(), - &(other.coeffRef(startCol)), other.innerStride(), - &other.coeffRef(startRow), other.innerStride(), - RhsScalar(-1)); - } - - for(Index k=0; k<actualPanelWidth; ++k) - { - Index i = IsLower ? pi+k : pi-k-1; - Index s = IsLower ? pi : i+1; - if (k>0) - other.coeffRef(i) -= (lhs.row(i).segment(s,k).transpose().cwiseProduct(other.segment(s,k))).sum(); + // FIXME find a way to allow an inner stride if packet_traits<Scalar>::size==1 - if(!(Mode & UnitDiag)) - other.coeffRef(i) /= lhs.coeff(i,i); - } + bool useRhsDirectly = Rhs::InnerStrideAtCompileTime==1 || rhs.innerStride()==1; + RhsScalar* actualRhs; + if(useRhsDirectly) + { + actualRhs = &rhs.coeffRef(0); } - } -}; - -// forward and backward substitution, column-major, rhs is a vector -template<typename Lhs, typename Rhs, int Mode> -struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,ColMajor,1> -{ - typedef typename Lhs::Scalar LhsScalar; - typedef typename Rhs::Scalar RhsScalar; - typedef blas_traits<Lhs> LhsProductTraits; - typedef typename LhsProductTraits::ExtractType ActualLhsType; - typedef typename Lhs::Index Index; - enum { - IsLower = ((Mode&Lower)==Lower) - }; - - static void run(const Lhs& lhs, Rhs& other) - { - static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH; - ActualLhsType actualLhs = LhsProductTraits::extract(lhs); - - const Index size = lhs.cols(); - for(Index pi=IsLower ? 0 : size; - IsLower ? pi<size : pi>0; - IsLower ? pi+=PanelWidth : pi-=PanelWidth) + else { - Index actualPanelWidth = std::min(IsLower ? size - pi : pi, PanelWidth); - Index startBlock = IsLower ? pi : pi-actualPanelWidth; - Index endBlock = IsLower ? pi + actualPanelWidth : 0; + actualRhs = ei_aligned_stack_new(RhsScalar,rhs.size()); + MappedRhs(actualRhs,rhs.size()) = rhs; + } - for(Index k=0; k<actualPanelWidth; ++k) - { - Index i = IsLower ? pi+k : pi-k-1; - if(!(Mode & UnitDiag)) - other.coeffRef(i) /= lhs.coeff(i,i); + + triangular_solve_vector<LhsScalar, RhsScalar, typename Lhs::Index, Mode, LhsProductTraits::NeedToConjugate, StorageOrder> + ::run(actualLhs.cols(), actualLhs.data(), actualLhs.outerStride(), actualRhs); - Index r = actualPanelWidth - k - 1; // remaining size - Index s = IsLower ? i+1 : i-r; - if (r>0) - other.segment(s,r) -= other.coeffRef(i) * Block<Lhs,Dynamic,1>(lhs, s, i, r, 1); - } - Index r = IsLower ? size - endBlock : startBlock; // remaining size - if (r > 0) - { - // let's directly call the low level product function because: - // 1 - it is faster to compile - // 2 - it is slighlty faster at runtime - general_matrix_vector_product<Index,LhsScalar,ColMajor,LhsProductTraits::NeedToConjugate,RhsScalar,false>::run( - r, actualPanelWidth, - &(actualLhs.const_cast_derived().coeffRef(endBlock,startBlock)), actualLhs.outerStride(), - &other.coeff(startBlock), other.innerStride(), - &(other.coeffRef(endBlock, 0)), other.innerStride(), RhsScalar(-1)); - } + if(!useRhsDirectly) + { + rhs = MappedRhs(actualRhs, rhs.size()); + ei_aligned_stack_delete(RhsScalar, actualRhs, rhs.size()); } } }; @@ -172,8 +110,6 @@ struct triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,Unrolling,StorageOrder } }; -template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder> -struct triangular_solve_matrix; // the rhs is a matrix template<typename Lhs, typename Rhs, int Side, int Mode, int StorageOrder> |