diff options
author | 2009-07-10 10:41:26 +0200 | |
---|---|---|
committer | 2009-07-10 10:41:26 +0200 | |
commit | 1a1b2e9f27db619303e7f212f9bf5c58a2dd988c (patch) | |
tree | 7b89bee276ce87a11514650abe45b709c1fbb966 /Eigen/src/Core/SolveTriangular.h | |
parent | 8885d56928e45b3beda91e529845e369a17d0a91 (diff) |
finally directly calling the low-level products is faster
Diffstat (limited to 'Eigen/src/Core/SolveTriangular.h')
-rw-r--r-- | Eigen/src/Core/SolveTriangular.h | 81 |
1 files changed, 39 insertions, 42 deletions
diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h index b28078fa1..452d40a4c 100644 --- a/Eigen/src/Core/SolveTriangular.h +++ b/Eigen/src/Core/SolveTriangular.h @@ -27,7 +27,6 @@ template<typename Lhs, typename Rhs, int Mode, // Upper/Lower | UnitDiag -// bool ConjugateLhs, bool ConjugateRhs, int UpLo = (Mode & LowerTriangularBit) ? LowerTriangular : (Mode & UpperTriangularBit) @@ -38,15 +37,20 @@ template<typename Lhs, typename Rhs, struct ei_triangular_solver_selector; // forward substitution, row-major -template<typename Lhs, typename Rhs, int Mode, /*bool ConjugateLhs, bool ConjugateRhs,*/ int UpLo> -struct ei_triangular_solver_selector<Lhs,Rhs,Mode,/*ConjugateLhs,ConjugateRhs,*/UpLo,RowMajor> +template<typename Lhs, typename Rhs, int Mode, int UpLo> +struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,RowMajor> { typedef typename Rhs::Scalar Scalar; + typedef ei_product_factor_traits<Lhs> LhsProductTraits; + typedef typename LhsProductTraits::ActualXprType ActualLhsType; + enum { + IsLowerTriangular = (UpLo==LowerTriangular) + }; static void run(const Lhs& lhs, Rhs& other) - {//std::cerr << "row maj " << ConjugateLhs << " , " << ConjugateRhs -// << " " << typeid(Lhs).name() << "\n"; - static const int PanelWidth = 40; // TODO make this a user definable constant - static const bool IsLowerTriangular = (UpLo==LowerTriangular); + {//std::cerr << "row maj " << LhsProductTraits::NeedToConjugate << "\n"; + static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH; + const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs); + const int size = lhs.cols(); for(int c=0 ; c<other.cols() ; ++c) { @@ -61,15 +65,12 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,/*ConjugateLhs,ConjugateRhs,*/ { int startRow = IsLowerTriangular ? pi : pi-actualPanelWidth; int startCol = IsLowerTriangular ? 0 : pi; -// Block<Rhs,Dynamic,1> target(other,startRow,c,actualPanelWidth,1); - -// ei_cache_friendly_product_rowmajor_times_vector<ConjugateLhs,ConjugateRhs>( -// &(lhs.const_cast_derived().coeffRef(startRow,startCol)), lhs.stride(), -// &(other.coeffRef(startCol, c)), r, -// target, Scalar(-1)); - other.col(c).segment(startRow,actualPanelWidth) -= - lhs.block(startRow,startCol,actualPanelWidth,r) - * other.col(c).segment(startCol,r); + Block<Rhs,Dynamic,1> target(other,startRow,c,actualPanelWidth,1); + + ei_cache_friendly_product_rowmajor_times_vector<LhsProductTraits::NeedToConjugate,false>( + &(actualLhs.const_cast_derived().coeffRef(startRow,startCol)), actualLhs.stride(), + &(other.coeffRef(startCol, c)), r, + target, Scalar(-1)); } for(int k=0; k<actualPanelWidth; ++k) @@ -83,7 +84,6 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,/*ConjugateLhs,ConjugateRhs,*/ if(!(Mode & UnitDiagBit)) other.coeffRef(i,c) /= lhs.coeff(i,i); } - } } } @@ -94,17 +94,23 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,/*ConjugateLhs,ConjugateRhs,*/ // - inv(LowerTriangular,UnitDiag,ColMajor) * Column vector // - inv(UpperTriangular, ColMajor) * Column vector // - inv(UpperTriangular,UnitDiag,ColMajor) * Column vector -template<typename Lhs, typename Rhs, int Mode, /*bool ConjugateLhs, bool ConjugateRhs,*/ int UpLo> -struct ei_triangular_solver_selector<Lhs,Rhs,Mode,/*ConjugateLhs,ConjugateRhs,*/UpLo,ColMajor> +template<typename Lhs, typename Rhs, int Mode, int UpLo> +struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor> { typedef typename Rhs::Scalar Scalar; typedef typename ei_packet_traits<Scalar>::type Packet; - enum { PacketSize = ei_packet_traits<Scalar>::size }; + typedef ei_product_factor_traits<Lhs> LhsProductTraits; + typedef typename LhsProductTraits::ActualXprType ActualLhsType; + enum { + PacketSize = ei_packet_traits<Scalar>::size, + IsLowerTriangular = (UpLo==LowerTriangular) + }; static void run(const Lhs& lhs, Rhs& other) - {//std::cerr << "col maj " << ConjugateLhs << " , " << ConjugateRhs << "\n"; - static const int PanelWidth = 4; // TODO make this a user definable constant - static const bool IsLowerTriangular = (UpLo==LowerTriangular); + {//std::cerr << "col maj " << LhsProductTraits::NeedToConjugate << "\n"; + static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH; + const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs); + const int size = lhs.cols(); for(int c=0 ; c<other.cols() ; ++c) { @@ -133,16 +139,15 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,/*ConjugateLhs,ConjugateRhs,*/ int r = IsLowerTriangular ? size - endBlock : startBlock; // remaining size if (r > 0) { -// ei_cache_friendly_product_colmajor_times_vector<ConjugateLhs,ConjugateRhs>( -// r, -// &(lhs.const_cast_derived().coeffRef(endBlock,startBlock)), lhs.stride(), -// other.col(c).segment(startBlock, actualPanelWidth), -// &(other.coeffRef(endBlock, c)), -// Scalar(-1)); - - other.col(c).segment(endBlock,r) -= - lhs.block(endBlock,startBlock,r,actualPanelWidth) - * other.col(c).segment(startBlock,actualPanelWidth); + // let's directly call this function because: + // 1 - it is faster to compile + // 2 - it is slighlty faster at runtime + ei_cache_friendly_product_colmajor_times_vector<LhsProductTraits::NeedToConjugate,false>( + r, + &(actualLhs.const_cast_derived().coeffRef(endBlock,startBlock)), actualLhs.stride(), + other.col(c).segment(startBlock, actualPanelWidth), + &(other.coeffRef(endBlock, c)), + Scalar(-1)); } } } @@ -168,21 +173,13 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>& ei_assert(!(Mode & ZeroDiagBit)); ei_assert(Mode & (UpperTriangularBit|LowerTriangularBit)); -// typedef ei_product_factor_traits<MatrixType> LhsProductTraits; -// typedef ei_product_factor_traits<RhsDerived> RhsProductTraits; -// typedef typename LhsProductTraits::ActualXprType ActualLhsType; -// typedef typename RhsProductTraits::ActualXprType ActualRhsType; -// const ActualLhsType& actualLhs = LhsProductTraits::extract(_expression()); -// ActualRhsType& actualRhs = const_cast<ActualRhsType&>(RhsProductTraits::extract(rhs)); - enum { copy = ei_traits<RhsDerived>::Flags & RowMajorBit }; -// std::cerr << typeid(MatrixType).name() << "\n"; typedef typename ei_meta_if<copy, typename ei_plain_matrix_type_column_major<RhsDerived>::type, RhsDerived&>::ret RhsCopy; RhsCopy rhsCopy(rhs); ei_triangular_solver_selector<MatrixType, typename ei_unref<RhsCopy>::type, - Mode/*, LhsProductTraits::NeedToConjugate,RhsProductTraits::NeedToConjugate*/>::run(_expression(), rhsCopy); + Mode>::run(_expression(), rhsCopy); if (copy) rhs = rhsCopy; |