aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/SolveTriangular.h
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2009-07-10 10:41:26 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2009-07-10 10:41:26 +0200
commit1a1b2e9f27db619303e7f212f9bf5c58a2dd988c (patch)
tree7b89bee276ce87a11514650abe45b709c1fbb966 /Eigen/src/Core/SolveTriangular.h
parent8885d56928e45b3beda91e529845e369a17d0a91 (diff)
finally directly calling the low-level products is faster
Diffstat (limited to 'Eigen/src/Core/SolveTriangular.h')
-rw-r--r--Eigen/src/Core/SolveTriangular.h81
1 files changed, 39 insertions, 42 deletions
diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h
index b28078fa1..452d40a4c 100644
--- a/Eigen/src/Core/SolveTriangular.h
+++ b/Eigen/src/Core/SolveTriangular.h
@@ -27,7 +27,6 @@
template<typename Lhs, typename Rhs,
int Mode, // Upper/Lower | UnitDiag
-// bool ConjugateLhs, bool ConjugateRhs,
int UpLo = (Mode & LowerTriangularBit)
? LowerTriangular
: (Mode & UpperTriangularBit)
@@ -38,15 +37,20 @@ template<typename Lhs, typename Rhs,
struct ei_triangular_solver_selector;
// forward substitution, row-major
-template<typename Lhs, typename Rhs, int Mode, /*bool ConjugateLhs, bool ConjugateRhs,*/ int UpLo>
-struct ei_triangular_solver_selector<Lhs,Rhs,Mode,/*ConjugateLhs,ConjugateRhs,*/UpLo,RowMajor>
+template<typename Lhs, typename Rhs, int Mode, int UpLo>
+struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,RowMajor>
{
typedef typename Rhs::Scalar Scalar;
+ typedef ei_product_factor_traits<Lhs> LhsProductTraits;
+ typedef typename LhsProductTraits::ActualXprType ActualLhsType;
+ enum {
+ IsLowerTriangular = (UpLo==LowerTriangular)
+ };
static void run(const Lhs& lhs, Rhs& other)
- {//std::cerr << "row maj " << ConjugateLhs << " , " << ConjugateRhs
-// << " " << typeid(Lhs).name() << "\n";
- static const int PanelWidth = 40; // TODO make this a user definable constant
- static const bool IsLowerTriangular = (UpLo==LowerTriangular);
+ {//std::cerr << "row maj " << LhsProductTraits::NeedToConjugate << "\n";
+ static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH;
+ const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs);
+
const int size = lhs.cols();
for(int c=0 ; c<other.cols() ; ++c)
{
@@ -61,15 +65,12 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,/*ConjugateLhs,ConjugateRhs,*/
{
int startRow = IsLowerTriangular ? pi : pi-actualPanelWidth;
int startCol = IsLowerTriangular ? 0 : pi;
-// Block<Rhs,Dynamic,1> target(other,startRow,c,actualPanelWidth,1);
-
-// ei_cache_friendly_product_rowmajor_times_vector<ConjugateLhs,ConjugateRhs>(
-// &(lhs.const_cast_derived().coeffRef(startRow,startCol)), lhs.stride(),
-// &(other.coeffRef(startCol, c)), r,
-// target, Scalar(-1));
- other.col(c).segment(startRow,actualPanelWidth) -=
- lhs.block(startRow,startCol,actualPanelWidth,r)
- * other.col(c).segment(startCol,r);
+ Block<Rhs,Dynamic,1> target(other,startRow,c,actualPanelWidth,1);
+
+ ei_cache_friendly_product_rowmajor_times_vector<LhsProductTraits::NeedToConjugate,false>(
+ &(actualLhs.const_cast_derived().coeffRef(startRow,startCol)), actualLhs.stride(),
+ &(other.coeffRef(startCol, c)), r,
+ target, Scalar(-1));
}
for(int k=0; k<actualPanelWidth; ++k)
@@ -83,7 +84,6 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,/*ConjugateLhs,ConjugateRhs,*/
if(!(Mode & UnitDiagBit))
other.coeffRef(i,c) /= lhs.coeff(i,i);
}
-
}
}
}
@@ -94,17 +94,23 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,/*ConjugateLhs,ConjugateRhs,*/
// - inv(LowerTriangular,UnitDiag,ColMajor) * Column vector
// - inv(UpperTriangular, ColMajor) * Column vector
// - inv(UpperTriangular,UnitDiag,ColMajor) * Column vector
-template<typename Lhs, typename Rhs, int Mode, /*bool ConjugateLhs, bool ConjugateRhs,*/ int UpLo>
-struct ei_triangular_solver_selector<Lhs,Rhs,Mode,/*ConjugateLhs,ConjugateRhs,*/UpLo,ColMajor>
+template<typename Lhs, typename Rhs, int Mode, int UpLo>
+struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
{
typedef typename Rhs::Scalar Scalar;
typedef typename ei_packet_traits<Scalar>::type Packet;
- enum { PacketSize = ei_packet_traits<Scalar>::size };
+ typedef ei_product_factor_traits<Lhs> LhsProductTraits;
+ typedef typename LhsProductTraits::ActualXprType ActualLhsType;
+ enum {
+ PacketSize = ei_packet_traits<Scalar>::size,
+ IsLowerTriangular = (UpLo==LowerTriangular)
+ };
static void run(const Lhs& lhs, Rhs& other)
- {//std::cerr << "col maj " << ConjugateLhs << " , " << ConjugateRhs << "\n";
- static const int PanelWidth = 4; // TODO make this a user definable constant
- static const bool IsLowerTriangular = (UpLo==LowerTriangular);
+ {//std::cerr << "col maj " << LhsProductTraits::NeedToConjugate << "\n";
+ static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH;
+ const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs);
+
const int size = lhs.cols();
for(int c=0 ; c<other.cols() ; ++c)
{
@@ -133,16 +139,15 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,/*ConjugateLhs,ConjugateRhs,*/
int r = IsLowerTriangular ? size - endBlock : startBlock; // remaining size
if (r > 0)
{
-// ei_cache_friendly_product_colmajor_times_vector<ConjugateLhs,ConjugateRhs>(
-// r,
-// &(lhs.const_cast_derived().coeffRef(endBlock,startBlock)), lhs.stride(),
-// other.col(c).segment(startBlock, actualPanelWidth),
-// &(other.coeffRef(endBlock, c)),
-// Scalar(-1));
-
- other.col(c).segment(endBlock,r) -=
- lhs.block(endBlock,startBlock,r,actualPanelWidth)
- * other.col(c).segment(startBlock,actualPanelWidth);
+ // let's directly call this function because:
+ // 1 - it is faster to compile
+ // 2 - it is slighlty faster at runtime
+ ei_cache_friendly_product_colmajor_times_vector<LhsProductTraits::NeedToConjugate,false>(
+ r,
+ &(actualLhs.const_cast_derived().coeffRef(endBlock,startBlock)), actualLhs.stride(),
+ other.col(c).segment(startBlock, actualPanelWidth),
+ &(other.coeffRef(endBlock, c)),
+ Scalar(-1));
}
}
}
@@ -168,21 +173,13 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>&
ei_assert(!(Mode & ZeroDiagBit));
ei_assert(Mode & (UpperTriangularBit|LowerTriangularBit));
-// typedef ei_product_factor_traits<MatrixType> LhsProductTraits;
-// typedef ei_product_factor_traits<RhsDerived> RhsProductTraits;
-// typedef typename LhsProductTraits::ActualXprType ActualLhsType;
-// typedef typename RhsProductTraits::ActualXprType ActualRhsType;
-// const ActualLhsType& actualLhs = LhsProductTraits::extract(_expression());
-// ActualRhsType& actualRhs = const_cast<ActualRhsType&>(RhsProductTraits::extract(rhs));
-
enum { copy = ei_traits<RhsDerived>::Flags & RowMajorBit };
-// std::cerr << typeid(MatrixType).name() << "\n";
typedef typename ei_meta_if<copy,
typename ei_plain_matrix_type_column_major<RhsDerived>::type, RhsDerived&>::ret RhsCopy;
RhsCopy rhsCopy(rhs);
ei_triangular_solver_selector<MatrixType, typename ei_unref<RhsCopy>::type,
- Mode/*, LhsProductTraits::NeedToConjugate,RhsProductTraits::NeedToConjugate*/>::run(_expression(), rhsCopy);
+ Mode>::run(_expression(), rhsCopy);
if (copy)
rhs = rhsCopy;