aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/SolveTriangular.h
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2010-11-05 12:43:14 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2010-11-05 12:43:14 +0100
commit0e6c1170abab3aac8eb79b5662fdb9edae77e3cf (patch)
treeb396fd25eb27e55ca033e55a9d9ed6a42aeff05a /Eigen/src/Core/SolveTriangular.h
parentfe1353080ea5760daea332a8904edd78c0a9fb36 (diff)
trsv: add support for inner-stride!=1, reduce code instanciation, move implementation to a new products/XX.h file
Diffstat (limited to 'Eigen/src/Core/SolveTriangular.h')
-rw-r--r--Eigen/src/Core/SolveTriangular.h124
1 files changed, 30 insertions, 94 deletions
diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h
index abbf57553..b950d2c31 100644
--- a/Eigen/src/Core/SolveTriangular.h
+++ b/Eigen/src/Core/SolveTriangular.h
@@ -27,6 +27,15 @@
namespace internal {
+// Forward declarations:
+// The following two routines are implemented in the products/TriangularSolver*.h files
+template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate, int StorageOrder>
+struct triangular_solve_vector;
+
+template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder>
+struct triangular_solve_matrix;
+
+// small helper struct extracting some traits on the underlying solver operation
template<typename Lhs, typename Rhs, int Side>
class trsolve_traits
{
@@ -51,111 +60,40 @@ template<typename Lhs, typename Rhs,
>
struct triangular_solver_selector;
-// forward and backward substitution, row-major, rhs is a vector
-template<typename Lhs, typename Rhs, int Mode>
-struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,RowMajor,1>
+template<typename Lhs, typename Rhs, int Mode, int StorageOrder>
+struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,StorageOrder,1>
{
typedef typename Lhs::Scalar LhsScalar;
typedef typename Rhs::Scalar RhsScalar;
typedef blas_traits<Lhs> LhsProductTraits;
typedef typename LhsProductTraits::ExtractType ActualLhsType;
- typedef typename Lhs::Index Index;
- enum {
- IsLower = ((Mode&Lower)==Lower)
- };
- static void run(const Lhs& lhs, Rhs& other)
+ typedef Map<Matrix<RhsScalar,Dynamic,1>, Aligned> MappedRhs;
+ static void run(const Lhs& lhs, Rhs& rhs)
{
- static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
- const Index size = lhs.cols();
- for(Index pi=IsLower ? 0 : size;
- IsLower ? pi<size : pi>0;
- IsLower ? pi+=PanelWidth : pi-=PanelWidth)
- {
- Index actualPanelWidth = std::min(IsLower ? size - pi : pi, PanelWidth);
-
- Index r = IsLower ? pi : size - pi; // remaining size
- if (r > 0)
- {
- // let's directly call the low level product function because:
- // 1 - it is faster to compile
- // 2 - it is slighlty faster at runtime
- Index startRow = IsLower ? pi : pi-actualPanelWidth;
- Index startCol = IsLower ? 0 : pi;
-
- general_matrix_vector_product<Index,LhsScalar,RowMajor,LhsProductTraits::NeedToConjugate,RhsScalar,false>::run(
- actualPanelWidth, r,
- &(actualLhs.const_cast_derived().coeffRef(startRow,startCol)), actualLhs.outerStride(),
- &(other.coeffRef(startCol)), other.innerStride(),
- &other.coeffRef(startRow), other.innerStride(),
- RhsScalar(-1));
- }
-
- for(Index k=0; k<actualPanelWidth; ++k)
- {
- Index i = IsLower ? pi+k : pi-k-1;
- Index s = IsLower ? pi : i+1;
- if (k>0)
- other.coeffRef(i) -= (lhs.row(i).segment(s,k).transpose().cwiseProduct(other.segment(s,k))).sum();
+ // FIXME find a way to allow an inner stride if packet_traits<Scalar>::size==1
- if(!(Mode & UnitDiag))
- other.coeffRef(i) /= lhs.coeff(i,i);
- }
+ bool useRhsDirectly = Rhs::InnerStrideAtCompileTime==1 || rhs.innerStride()==1;
+ RhsScalar* actualRhs;
+ if(useRhsDirectly)
+ {
+ actualRhs = &rhs.coeffRef(0);
}
- }
-};
-
-// forward and backward substitution, column-major, rhs is a vector
-template<typename Lhs, typename Rhs, int Mode>
-struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,ColMajor,1>
-{
- typedef typename Lhs::Scalar LhsScalar;
- typedef typename Rhs::Scalar RhsScalar;
- typedef blas_traits<Lhs> LhsProductTraits;
- typedef typename LhsProductTraits::ExtractType ActualLhsType;
- typedef typename Lhs::Index Index;
- enum {
- IsLower = ((Mode&Lower)==Lower)
- };
-
- static void run(const Lhs& lhs, Rhs& other)
- {
- static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
- ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
-
- const Index size = lhs.cols();
- for(Index pi=IsLower ? 0 : size;
- IsLower ? pi<size : pi>0;
- IsLower ? pi+=PanelWidth : pi-=PanelWidth)
+ else
{
- Index actualPanelWidth = std::min(IsLower ? size - pi : pi, PanelWidth);
- Index startBlock = IsLower ? pi : pi-actualPanelWidth;
- Index endBlock = IsLower ? pi + actualPanelWidth : 0;
+ actualRhs = ei_aligned_stack_new(RhsScalar,rhs.size());
+ MappedRhs(actualRhs,rhs.size()) = rhs;
+ }
- for(Index k=0; k<actualPanelWidth; ++k)
- {
- Index i = IsLower ? pi+k : pi-k-1;
- if(!(Mode & UnitDiag))
- other.coeffRef(i) /= lhs.coeff(i,i);
+
+ triangular_solve_vector<LhsScalar, RhsScalar, typename Lhs::Index, Mode, LhsProductTraits::NeedToConjugate, StorageOrder>
+ ::run(actualLhs.cols(), actualLhs.data(), actualLhs.outerStride(), actualRhs);
- Index r = actualPanelWidth - k - 1; // remaining size
- Index s = IsLower ? i+1 : i-r;
- if (r>0)
- other.segment(s,r) -= other.coeffRef(i) * Block<Lhs,Dynamic,1>(lhs, s, i, r, 1);
- }
- Index r = IsLower ? size - endBlock : startBlock; // remaining size
- if (r > 0)
- {
- // let's directly call the low level product function because:
- // 1 - it is faster to compile
- // 2 - it is slighlty faster at runtime
- general_matrix_vector_product<Index,LhsScalar,ColMajor,LhsProductTraits::NeedToConjugate,RhsScalar,false>::run(
- r, actualPanelWidth,
- &(actualLhs.const_cast_derived().coeffRef(endBlock,startBlock)), actualLhs.outerStride(),
- &other.coeff(startBlock), other.innerStride(),
- &(other.coeffRef(endBlock, 0)), other.innerStride(), RhsScalar(-1));
- }
+ if(!useRhsDirectly)
+ {
+ rhs = MappedRhs(actualRhs, rhs.size());
+ ei_aligned_stack_delete(RhsScalar, actualRhs, rhs.size());
}
}
};
@@ -172,8 +110,6 @@ struct triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,Unrolling,StorageOrder
}
};
-template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder>
-struct triangular_solve_matrix;
// the rhs is a matrix
template<typename Lhs, typename Rhs, int Side, int Mode, int StorageOrder>