diff options
-rw-r--r-- | Eigen/src/Core/SolveTriangular.h | 84 | ||||
-rw-r--r-- | test/triangular.cpp | 12 |
2 files changed, 72 insertions, 24 deletions
diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h index 452d40a4c..3a65a8b27 100644 --- a/Eigen/src/Core/SolveTriangular.h +++ b/Eigen/src/Core/SolveTriangular.h @@ -26,28 +26,25 @@ #define EIGEN_SOLVETRIANGULAR_H template<typename Lhs, typename Rhs, - int Mode, // Upper/Lower | UnitDiag - int UpLo = (Mode & LowerTriangularBit) - ? LowerTriangular - : (Mode & UpperTriangularBit) - ? UpperTriangular - : -1, + int Mode, // can be Upper/Lower | UnitDiag + int Unrolling = Rhs::IsVectorAtCompileTime && Rhs::SizeAtCompileTime <= 8 // FIXME + ? CompleteUnrolling : NoUnrolling, int StorageOrder = int(Lhs::Flags) & RowMajorBit > struct ei_triangular_solver_selector; // forward substitution, row-major -template<typename Lhs, typename Rhs, int Mode, int UpLo> -struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,RowMajor> +template<typename Lhs, typename Rhs, int Mode> +struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,RowMajor> { typedef typename Rhs::Scalar Scalar; typedef ei_product_factor_traits<Lhs> LhsProductTraits; typedef typename LhsProductTraits::ActualXprType ActualLhsType; enum { - IsLowerTriangular = (UpLo==LowerTriangular) + IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit) }; static void run(const Lhs& lhs, Rhs& other) - {//std::cerr << "row maj " << LhsProductTraits::NeedToConjugate << "\n"; + { static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH; const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs); @@ -90,12 +87,12 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,RowMajor> }; // Implements the following configurations: -// - inv(LowerTriangular, ColMajor) * Column vector -// - inv(LowerTriangular,UnitDiag,ColMajor) * Column vector -// - inv(UpperTriangular, ColMajor) * Column vector -// - inv(UpperTriangular,UnitDiag,ColMajor) * Column vector -template<typename Lhs, typename Rhs, int Mode, int UpLo> -struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor> +// - inv(LowerTriangular, ColMajor) * Column vectors +// - inv(LowerTriangular,UnitDiag,ColMajor) * Column vectors +// - inv(UpperTriangular, ColMajor) * Column vectors +// - inv(UpperTriangular,UnitDiag,ColMajor) * Column vectors +template<typename Lhs, typename Rhs, int Mode> +struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,ColMajor> { typedef typename Rhs::Scalar Scalar; typedef typename ei_packet_traits<Scalar>::type Packet; @@ -103,11 +100,11 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor> typedef typename LhsProductTraits::ActualXprType ActualLhsType; enum { PacketSize = ei_packet_traits<Scalar>::size, - IsLowerTriangular = (UpLo==LowerTriangular) + IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit) }; static void run(const Lhs& lhs, Rhs& other) - {//std::cerr << "col maj " << LhsProductTraits::NeedToConjugate << "\n"; + { static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH; const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs); @@ -154,6 +151,49 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor> } }; +/*************************************************************************** +* meta-unrolling implementation +***************************************************************************/ + +template<typename Lhs, typename Rhs, int Mode, int Index, int Size, + bool Stop = Index==Size> +struct ei_triangular_solver_unroller; + +template<typename Lhs, typename Rhs, int Mode, int Index, int Size> +struct ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,false> { + enum { + IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit), + I = IsLowerTriangular ? Index : Size - Index - 1, + S = IsLowerTriangular ? 0 : I+1 + }; + static void run(const Lhs& lhs, Rhs& rhs) + { + if (Index>0) + rhs.coeffRef(I) -= ((lhs.row(I).template segment<Index>(S).transpose()) + .cwise()*(rhs.template segment<Index>(S))).sum(); + + if(!(Mode & UnitDiagBit)) + rhs.coeffRef(I) /= lhs.coeff(I,I); + + ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index+1,Size>::run(lhs,rhs); + } +}; + +template<typename Lhs, typename Rhs, int Mode, int Index, int Size> +struct ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,true> { + static void run(const Lhs& lhs, Rhs& rhs) {} +}; + +template<typename Lhs, typename Rhs, int Mode, int StorageOrder> +struct ei_triangular_solver_selector<Lhs,Rhs,Mode,CompleteUnrolling,StorageOrder> { + static void run(const Lhs& lhs, Rhs& rhs) + { ei_triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); } +}; + +/*************************************************************************** +* TriangularView methods +***************************************************************************/ + /** "in-place" version of MatrixBase::solveTriangular() where the result is written in \a other * * \nonstableyet @@ -161,7 +201,7 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor> * \warning The parameter is only marked 'const' to make the C++ compiler accept a temporary expression here. * This function will const_cast it, so constness isn't honored here. * - * See MatrixBase:solveTriangular() for the details. + * See TriangularView:solve() for the details. */ template<typename MatrixType, unsigned int Mode> template<typename RhsDerived> @@ -198,8 +238,6 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>& * can be done by marked(), and that is automatically the case with expressions such as those returned * by extract(). * - * \addexample SolveTriangular \label How to solve a triangular system (aka. how to multiply the inverse of a triangular matrix by another one) - * * Example: \include MatrixBase_marked.cpp * Output: \verbinclude MatrixBase_marked.out * @@ -213,10 +251,10 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>& * * \b Tips: to perform a \em "right-inverse-multiply" you can simply transpose the operation, e.g.: * \code - * M * T^1 <=> T.transpose().solveTriangularInPlace(M.transpose()); + * M * T^1 <=> T.transpose().solveInPlace(M.transpose()); * \endcode * - * \sa solveTriangularInPlace() + * \sa TriangularView::solveInPlace() */ template<typename Derived, unsigned int Mode> template<typename RhsDerived> diff --git a/test/triangular.cpp b/test/triangular.cpp index 0c03e987e..7c680a8ed 100644 --- a/test/triangular.cpp +++ b/test/triangular.cpp @@ -86,7 +86,17 @@ template<typename MatrixType> void triangular(const MatrixType& m) while (ei_abs2(m1(i,i))<1e-3) m1(i,i) = ei_random<Scalar>(); Transpose<MatrixType> trm4(m4); - // test back and forward subsitution + // test back and forward subsitution with a vector as the rhs + m3 = m1.template triangularView<Eigen::UpperTriangular>(); + VERIFY(v2.isApprox(m3.adjoint() * (m1.adjoint().template triangularView<Eigen::LowerTriangular>().solve(v2)), largerEps)); + m3 = m1.template triangularView<Eigen::LowerTriangular>(); + VERIFY(v2.isApprox(m3.transpose() * (m1.transpose().template triangularView<Eigen::UpperTriangular>().solve(v2)), largerEps)); + m3 = m1.template triangularView<Eigen::UpperTriangular>(); + VERIFY(v2.isApprox(m3 * (m1.template triangularView<Eigen::UpperTriangular>().solve(v2)), largerEps)); + m3 = m1.template triangularView<Eigen::LowerTriangular>(); + VERIFY(v2.isApprox(m3.conjugate() * (m1.conjugate().template triangularView<Eigen::LowerTriangular>().solve(v2)), largerEps)); + + // test back and forward subsitution with a matrix as the rhs m3 = m1.template triangularView<Eigen::UpperTriangular>(); VERIFY(m2.isApprox(m3.adjoint() * (m1.adjoint().template triangularView<Eigen::LowerTriangular>().solve(m2)), largerEps)); m3 = m1.template triangularView<Eigen::LowerTriangular>(); |