diff options
author | Gael Guennebaud <g.gael@free.fr> | 2010-11-05 12:54:32 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2010-11-05 12:54:32 +0100 |
commit | 3fdea699b80c429738ac0af8c9b7479594b90583 (patch) | |
tree | ab9f6f36fdf3632961516e4526878dc869199860 /Eigen/src/Core/SolveTriangular.h | |
parent | 0e6c1170abab3aac8eb79b5662fdb9edae77e3cf (diff) |
trsv: simplifications/cleaning
Diffstat (limited to 'Eigen/src/Core/SolveTriangular.h')
-rw-r--r-- | Eigen/src/Core/SolveTriangular.h | 37 |
1 files changed, 12 insertions, 25 deletions
diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h index b950d2c31..d85f967cb 100644 --- a/Eigen/src/Core/SolveTriangular.h +++ b/Eigen/src/Core/SolveTriangular.h @@ -29,7 +29,7 @@ namespace internal { // Forward declarations: // The following two routines are implemented in the products/TriangularSolver*.h files -template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate, int StorageOrder> +template<typename LhsScalar, typename RhsScalar, typename Index, int Side, int Mode, bool Conjugate, int StorageOrder> struct triangular_solve_vector; template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder> @@ -55,13 +55,12 @@ template<typename Lhs, typename Rhs, int Side, // can be OnTheLeft/OnTheRight int Mode, // can be Upper/Lower | UnitDiag int Unrolling = trsolve_traits<Lhs,Rhs,Side>::Unrolling, - int StorageOrder = (int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor, int RhsVectors = trsolve_traits<Lhs,Rhs,Side>::RhsVectors > struct triangular_solver_selector; -template<typename Lhs, typename Rhs, int Mode, int StorageOrder> -struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,StorageOrder,1> +template<typename Lhs, typename Rhs, int Side, int Mode> +struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,1> { typedef typename Lhs::Scalar LhsScalar; typedef typename Rhs::Scalar RhsScalar; @@ -86,8 +85,8 @@ struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,StorageOrde MappedRhs(actualRhs,rhs.size()) = rhs; } - - triangular_solve_vector<LhsScalar, RhsScalar, typename Lhs::Index, Mode, LhsProductTraits::NeedToConjugate, StorageOrder> + triangular_solve_vector<LhsScalar, RhsScalar, typename Lhs::Index, Side, Mode, LhsProductTraits::NeedToConjugate, + (int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor> ::run(actualLhs.cols(), actualLhs.data(), actualLhs.outerStride(), actualRhs); if(!useRhsDirectly) @@ -98,22 +97,9 @@ struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,StorageOrde } }; -// transpose OnTheRight cases for vectors -template<typename Lhs, typename Rhs, int Mode, int Unrolling, int StorageOrder> -struct triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,Unrolling,StorageOrder,1> -{ - static void run(const Lhs& lhs, Rhs& rhs) - { - Transpose<Rhs> rhsTr(rhs); - Transpose<Lhs> lhsTr(lhs); - triangular_solver_selector<Transpose<Lhs>,Transpose<Rhs>,OnTheLeft,TriangularView<Lhs,Mode>::TransposeMode>::run(lhsTr,rhsTr); - } -}; - - // the rhs is a matrix -template<typename Lhs, typename Rhs, int Side, int Mode, int StorageOrder> -struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,StorageOrder,Dynamic> +template<typename Lhs, typename Rhs, int Side, int Mode> +struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,Dynamic> { typedef typename Rhs::Scalar Scalar; typedef typename Rhs::Index Index; @@ -122,7 +108,7 @@ struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,StorageOrder,Dyn static void run(const Lhs& lhs, Rhs& rhs) { const ActualLhsType actualLhs = LhsProductTraits::extract(lhs); - triangular_solve_matrix<Scalar,Index,Side,Mode,LhsProductTraits::NeedToConjugate,StorageOrder, + triangular_solve_matrix<Scalar,Index,Side,Mode,LhsProductTraits::NeedToConjugate,(int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor, (Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor> ::run(lhs.rows(), Side==OnTheLeft? rhs.cols() : rhs.rows(), &actualLhs.coeff(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.outerStride()); } @@ -146,7 +132,8 @@ struct triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,false> { static void run(const Lhs& lhs, Rhs& rhs) { if (Index>0) - rhs.coeffRef(I) -= lhs.row(I).template segment<Index>(S).transpose().cwiseProduct(rhs.template segment<Index>(S)).sum(); + rhs.coeffRef(I) -= lhs.row(I).template segment<Index>(S).transpose() + .cwiseProduct(rhs.template segment<Index>(S)).sum(); if(!(Mode & UnitDiag)) rhs.coeffRef(I) /= lhs.coeff(I,I); @@ -160,8 +147,8 @@ struct triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,true> { static void run(const Lhs&, Rhs&) {} }; -template<typename Lhs, typename Rhs, int Mode, int StorageOrder> -struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,CompleteUnrolling,StorageOrder,1> { +template<typename Lhs, typename Rhs, int Mode> +struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,CompleteUnrolling,1> { static void run(const Lhs& lhs, Rhs& rhs) { triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); } }; |