aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/SparseCore/TriangularSolver.h
diff options
context:
space:
mode:
Diffstat (limited to 'Eigen/src/SparseCore/TriangularSolver.h')
-rw-r--r--Eigen/src/SparseCore/TriangularSolver.h56
1 files changed, 28 insertions, 28 deletions
diff --git a/Eigen/src/SparseCore/TriangularSolver.h b/Eigen/src/SparseCore/TriangularSolver.h
index dd55522a7..98062e9c6 100644
--- a/Eigen/src/SparseCore/TriangularSolver.h
+++ b/Eigen/src/SparseCore/TriangularSolver.h
@@ -29,8 +29,11 @@ struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,RowMajor>
{
typedef typename Rhs::Scalar Scalar;
typedef typename Lhs::Index Index;
+ typedef typename evaluator<Lhs>::type LhsEval;
+ typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
static void run(const Lhs& lhs, Rhs& other)
{
+ LhsEval lhsEval(lhs);
for(Index col=0 ; col<other.cols() ; ++col)
{
for(Index i=0; i<lhs.rows(); ++i)
@@ -38,7 +41,7 @@ struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,RowMajor>
Scalar tmp = other.coeff(i,col);
Scalar lastVal(0);
Index lastIndex = 0;
- for(typename Lhs::InnerIterator it(lhs, i); it; ++it)
+ for(LhsIterator it(lhsEval, i); it; ++it)
{
lastVal = it.value();
lastIndex = it.index();
@@ -64,15 +67,18 @@ struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,RowMajor>
{
typedef typename Rhs::Scalar Scalar;
typedef typename Lhs::Index Index;
+ typedef typename evaluator<Lhs>::type LhsEval;
+ typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
static void run(const Lhs& lhs, Rhs& other)
{
+ LhsEval lhsEval(lhs);
for(Index col=0 ; col<other.cols() ; ++col)
{
for(Index i=lhs.rows()-1 ; i>=0 ; --i)
{
Scalar tmp = other.coeff(i,col);
Scalar l_ii = 0;
- typename Lhs::InnerIterator it(lhs, i);
+ LhsIterator it(lhsEval, i);
while(it && it.index()<i)
++it;
if(!(Mode & UnitDiag))
@@ -88,10 +94,8 @@ struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,RowMajor>
tmp -= it.value() * other.coeff(it.index(),col);
}
- if (Mode & UnitDiag)
- other.coeffRef(i,col) = tmp;
- else
- other.coeffRef(i,col) = tmp/l_ii;
+ if (Mode & UnitDiag) other.coeffRef(i,col) = tmp;
+ else other.coeffRef(i,col) = tmp/l_ii;
}
}
}
@@ -103,8 +107,11 @@ struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,ColMajor>
{
typedef typename Rhs::Scalar Scalar;
typedef typename Lhs::Index Index;
+ typedef typename evaluator<Lhs>::type LhsEval;
+ typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
static void run(const Lhs& lhs, Rhs& other)
{
+ LhsEval lhsEval(lhs);
for(Index col=0 ; col<other.cols() ; ++col)
{
for(Index i=0; i<lhs.cols(); ++i)
@@ -112,7 +119,7 @@ struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,ColMajor>
Scalar& tmp = other.coeffRef(i,col);
if (tmp!=Scalar(0)) // optimization when other is actually sparse
{
- typename Lhs::InnerIterator it(lhs, i);
+ LhsIterator it(lhsEval, i);
while(it && it.index()<i)
++it;
if(!(Mode & UnitDiag))
@@ -136,8 +143,11 @@ struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,ColMajor>
{
typedef typename Rhs::Scalar Scalar;
typedef typename Lhs::Index Index;
+ typedef typename evaluator<Lhs>::type LhsEval;
+ typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
static void run(const Lhs& lhs, Rhs& other)
{
+ LhsEval lhsEval(lhs);
for(Index col=0 ; col<other.cols() ; ++col)
{
for(Index i=lhs.cols()-1; i>=0; --i)
@@ -148,13 +158,13 @@ struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,ColMajor>
if(!(Mode & UnitDiag))
{
// TODO replace this by a binary search. make sure the binary search is safe for partially sorted elements
- typename Lhs::ReverseInnerIterator it(lhs, i);
+ LhsIterator it(lhsEval, i);
while(it && it.index()!=i)
- --it;
+ ++it;
eigen_assert(it && it.index()==i);
other.coeffRef(i,col) /= it.value();
}
- typename Lhs::InnerIterator it(lhs, i);
+ LhsIterator it(lhsEval, i);
for(; it && it.index()<i; ++it)
other.coeffRef(it.index(), col) -= tmp * it.value();
}
@@ -165,11 +175,11 @@ struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,ColMajor>
} // end namespace internal
-template<typename ExpressionType,int Mode>
+template<typename ExpressionType,unsigned int Mode>
template<typename OtherDerived>
-void SparseTriangularView<ExpressionType,Mode>::solveInPlace(MatrixBase<OtherDerived>& other) const
+void TriangularViewImpl<ExpressionType,Mode,Sparse>::solveInPlace(MatrixBase<OtherDerived>& other) const
{
- eigen_assert(m_matrix.cols() == m_matrix.rows() && m_matrix.cols() == other.rows());
+ eigen_assert(derived().cols() == derived().rows() && derived().cols() == other.rows());
eigen_assert((!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower)));
enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit };
@@ -178,22 +188,12 @@ void SparseTriangularView<ExpressionType,Mode>::solveInPlace(MatrixBase<OtherDer
typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy;
OtherCopy otherCopy(other.derived());
- internal::sparse_solve_triangular_selector<ExpressionType, typename internal::remove_reference<OtherCopy>::type, Mode>::run(m_matrix, otherCopy);
+ internal::sparse_solve_triangular_selector<ExpressionType, typename internal::remove_reference<OtherCopy>::type, Mode>::run(derived().nestedExpression(), otherCopy);
if (copy)
other = otherCopy;
}
-template<typename ExpressionType,int Mode>
-template<typename OtherDerived>
-typename internal::plain_matrix_type_column_major<OtherDerived>::type
-SparseTriangularView<ExpressionType,Mode>::solve(const MatrixBase<OtherDerived>& other) const
-{
- typename internal::plain_matrix_type_column_major<OtherDerived>::type res(other);
- solveInPlace(res);
- return res;
-}
-
// pure sparse path
namespace internal {
@@ -290,11 +290,11 @@ struct sparse_solve_triangular_sparse_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
} // end namespace internal
-template<typename ExpressionType,int Mode>
+template<typename ExpressionType,unsigned int Mode>
template<typename OtherDerived>
-void SparseTriangularView<ExpressionType,Mode>::solveInPlace(SparseMatrixBase<OtherDerived>& other) const
+void TriangularViewImpl<ExpressionType,Mode,Sparse>::solveInPlace(SparseMatrixBase<OtherDerived>& other) const
{
- eigen_assert(m_matrix.cols() == m_matrix.rows() && m_matrix.cols() == other.rows());
+ eigen_assert(derived().cols() == derived().rows() && derived().cols() == other.rows());
eigen_assert( (!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower)));
// enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit };
@@ -303,7 +303,7 @@ void SparseTriangularView<ExpressionType,Mode>::solveInPlace(SparseMatrixBase<Ot
// typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy;
// OtherCopy otherCopy(other.derived());
- internal::sparse_solve_triangular_sparse_selector<ExpressionType, OtherDerived, Mode>::run(m_matrix, other.derived());
+ internal::sparse_solve_triangular_sparse_selector<ExpressionType, OtherDerived, Mode>::run(derived().nestedExpression(), other.derived());
// if (copy)
// other = otherCopy;