diff options
author | Gael Guennebaud <g.gael@free.fr> | 2009-07-10 11:30:46 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2009-07-10 11:30:46 +0200 |
commit | b47dea8b7aeab10cf584f2d3275192d90d8df2ed (patch) | |
tree | f9647716407f4fe897d8d277e9f99cfda08a9c86 /Eigen/src | |
parent | 1a1b2e9f27db619303e7f212f9bf5c58a2dd988c (diff) |
add a meta unroller for the triangular solver (only for vectors as rhs)
Diffstat (limited to 'Eigen/src')
-rw-r--r-- | Eigen/src/Core/SolveTriangular.h | 84 |
1 files changed, 61 insertions, 23 deletions
diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h index 452d40a4c..3a65a8b27 100644 --- a/Eigen/src/Core/SolveTriangular.h +++ b/Eigen/src/Core/SolveTriangular.h @@ -26,28 +26,25 @@ #define EIGEN_SOLVETRIANGULAR_H template<typename Lhs, typename Rhs, - int Mode, // Upper/Lower | UnitDiag - int UpLo = (Mode & LowerTriangularBit) - ? LowerTriangular - : (Mode & UpperTriangularBit) - ? UpperTriangular - : -1, + int Mode, // can be Upper/Lower | UnitDiag + int Unrolling = Rhs::IsVectorAtCompileTime && Rhs::SizeAtCompileTime <= 8 // FIXME + ? CompleteUnrolling : NoUnrolling, int StorageOrder = int(Lhs::Flags) & RowMajorBit > struct ei_triangular_solver_selector; // forward substitution, row-major -template<typename Lhs, typename Rhs, int Mode, int UpLo> -struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,RowMajor> +template<typename Lhs, typename Rhs, int Mode> +struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,RowMajor> { typedef typename Rhs::Scalar Scalar; typedef ei_product_factor_traits<Lhs> LhsProductTraits; typedef typename LhsProductTraits::ActualXprType ActualLhsType; enum { - IsLowerTriangular = (UpLo==LowerTriangular) + IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit) }; static void run(const Lhs& lhs, Rhs& other) - {//std::cerr << "row maj " << LhsProductTraits::NeedToConjugate << "\n"; + { static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH; const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs); @@ -90,12 +87,12 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,RowMajor> }; // Implements the following configurations: -// - inv(LowerTriangular, ColMajor) * Column vector -// - inv(LowerTriangular,UnitDiag,ColMajor) * Column vector -// - inv(UpperTriangular, ColMajor) * Column vector -// - inv(UpperTriangular,UnitDiag,ColMajor) * Column vector -template<typename Lhs, typename Rhs, int Mode, int UpLo> -struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor> +// - inv(LowerTriangular, ColMajor) * Column vectors +// - inv(LowerTriangular,UnitDiag,ColMajor) * Column vectors +// - inv(UpperTriangular, ColMajor) * Column vectors +// - inv(UpperTriangular,UnitDiag,ColMajor) * Column vectors +template<typename Lhs, typename Rhs, int Mode> +struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,ColMajor> { typedef typename Rhs::Scalar Scalar; typedef typename ei_packet_traits<Scalar>::type Packet; @@ -103,11 +100,11 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor> typedef typename LhsProductTraits::ActualXprType ActualLhsType; enum { PacketSize = ei_packet_traits<Scalar>::size, - IsLowerTriangular = (UpLo==LowerTriangular) + IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit) }; static void run(const Lhs& lhs, Rhs& other) - {//std::cerr << "col maj " << LhsProductTraits::NeedToConjugate << "\n"; + { static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH; const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs); @@ -154,6 +151,49 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor> } }; +/*************************************************************************** +* meta-unrolling implementation +***************************************************************************/ + +template<typename Lhs, typename Rhs, int Mode, int Index, int Size, + bool Stop = Index==Size> +struct ei_triangular_solver_unroller; + +template<typename Lhs, typename Rhs, int Mode, int Index, int Size> +struct ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,false> { + enum { + IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit), + I = IsLowerTriangular ? Index : Size - Index - 1, + S = IsLowerTriangular ? 0 : I+1 + }; + static void run(const Lhs& lhs, Rhs& rhs) + { + if (Index>0) + rhs.coeffRef(I) -= ((lhs.row(I).template segment<Index>(S).transpose()) + .cwise()*(rhs.template segment<Index>(S))).sum(); + + if(!(Mode & UnitDiagBit)) + rhs.coeffRef(I) /= lhs.coeff(I,I); + + ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index+1,Size>::run(lhs,rhs); + } +}; + +template<typename Lhs, typename Rhs, int Mode, int Index, int Size> +struct ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,true> { + static void run(const Lhs& lhs, Rhs& rhs) {} +}; + +template<typename Lhs, typename Rhs, int Mode, int StorageOrder> +struct ei_triangular_solver_selector<Lhs,Rhs,Mode,CompleteUnrolling,StorageOrder> { + static void run(const Lhs& lhs, Rhs& rhs) + { ei_triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); } +}; + +/*************************************************************************** +* TriangularView methods +***************************************************************************/ + /** "in-place" version of MatrixBase::solveTriangular() where the result is written in \a other * * \nonstableyet @@ -161,7 +201,7 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor> * \warning The parameter is only marked 'const' to make the C++ compiler accept a temporary expression here. * This function will const_cast it, so constness isn't honored here. * - * See MatrixBase:solveTriangular() for the details. + * See TriangularView:solve() for the details. */ template<typename MatrixType, unsigned int Mode> template<typename RhsDerived> @@ -198,8 +238,6 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>& * can be done by marked(), and that is automatically the case with expressions such as those returned * by extract(). * - * \addexample SolveTriangular \label How to solve a triangular system (aka. how to multiply the inverse of a triangular matrix by another one) - * * Example: \include MatrixBase_marked.cpp * Output: \verbinclude MatrixBase_marked.out * @@ -213,10 +251,10 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>& * * \b Tips: to perform a \em "right-inverse-multiply" you can simply transpose the operation, e.g.: * \code - * M * T^1 <=> T.transpose().solveTriangularInPlace(M.transpose()); + * M * T^1 <=> T.transpose().solveInPlace(M.transpose()); * \endcode * - * \sa solveTriangularInPlace() + * \sa TriangularView::solveInPlace() */ template<typename Derived, unsigned int Mode> template<typename RhsDerived> |