aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2010-11-05 12:43:14 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2010-11-05 12:43:14 +0100
commit0e6c1170abab3aac8eb79b5662fdb9edae77e3cf (patch)
treeb396fd25eb27e55ca033e55a9d9ed6a42aeff05a /Eigen
parentfe1353080ea5760daea332a8904edd78c0a9fb36 (diff)
trsv: add support for inner-stride!=1, reduce code instantiation, move implementation to a new products/XX.h file
Diffstat (limited to 'Eigen')
-rw-r--r--Eigen/Core1
-rw-r--r--Eigen/src/Core/SolveTriangular.h124
-rw-r--r--Eigen/src/Core/products/TriangularSolverVector.h138
3 files changed, 169 insertions, 94 deletions
diff --git a/Eigen/Core b/Eigen/Core
index 4401a4e14..5e81d10d1 100644
--- a/Eigen/Core
+++ b/Eigen/Core
@@ -315,6 +315,7 @@ using std::size_t;
#include "src/Core/products/TriangularMatrixVector.h"
#include "src/Core/products/TriangularMatrixMatrix.h"
#include "src/Core/products/TriangularSolverMatrix.h"
+#include "src/Core/products/TriangularSolverVector.h"
#include "src/Core/BandMatrix.h"
#include "src/Core/BooleanRedux.h"
diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h
index abbf57553..b950d2c31 100644
--- a/Eigen/src/Core/SolveTriangular.h
+++ b/Eigen/src/Core/SolveTriangular.h
@@ -27,6 +27,15 @@
namespace internal {
+// Forward declarations:
+// The following two routines are implemented in the products/TriangularSolver*.h files
+template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate, int StorageOrder>
+struct triangular_solve_vector;
+
+template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder>
+struct triangular_solve_matrix;
+
+// small helper struct extracting some traits on the underlying solver operation
template<typename Lhs, typename Rhs, int Side>
class trsolve_traits
{
@@ -51,111 +60,40 @@ template<typename Lhs, typename Rhs,
>
struct triangular_solver_selector;
-// forward and backward substitution, row-major, rhs is a vector
-template<typename Lhs, typename Rhs, int Mode>
-struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,RowMajor,1>
+template<typename Lhs, typename Rhs, int Mode, int StorageOrder>
+struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,StorageOrder,1>
{
typedef typename Lhs::Scalar LhsScalar;
typedef typename Rhs::Scalar RhsScalar;
typedef blas_traits<Lhs> LhsProductTraits;
typedef typename LhsProductTraits::ExtractType ActualLhsType;
- typedef typename Lhs::Index Index;
- enum {
- IsLower = ((Mode&Lower)==Lower)
- };
- static void run(const Lhs& lhs, Rhs& other)
+ typedef Map<Matrix<RhsScalar,Dynamic,1>, Aligned> MappedRhs;
+ static void run(const Lhs& lhs, Rhs& rhs)
{
- static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
- const Index size = lhs.cols();
- for(Index pi=IsLower ? 0 : size;
- IsLower ? pi<size : pi>0;
- IsLower ? pi+=PanelWidth : pi-=PanelWidth)
- {
- Index actualPanelWidth = std::min(IsLower ? size - pi : pi, PanelWidth);
-
- Index r = IsLower ? pi : size - pi; // remaining size
- if (r > 0)
- {
- // let's directly call the low level product function because:
- // 1 - it is faster to compile
- // 2 - it is slighlty faster at runtime
- Index startRow = IsLower ? pi : pi-actualPanelWidth;
- Index startCol = IsLower ? 0 : pi;
-
- general_matrix_vector_product<Index,LhsScalar,RowMajor,LhsProductTraits::NeedToConjugate,RhsScalar,false>::run(
- actualPanelWidth, r,
- &(actualLhs.const_cast_derived().coeffRef(startRow,startCol)), actualLhs.outerStride(),
- &(other.coeffRef(startCol)), other.innerStride(),
- &other.coeffRef(startRow), other.innerStride(),
- RhsScalar(-1));
- }
-
- for(Index k=0; k<actualPanelWidth; ++k)
- {
- Index i = IsLower ? pi+k : pi-k-1;
- Index s = IsLower ? pi : i+1;
- if (k>0)
- other.coeffRef(i) -= (lhs.row(i).segment(s,k).transpose().cwiseProduct(other.segment(s,k))).sum();
+ // FIXME find a way to allow an inner stride if packet_traits<Scalar>::size==1
- if(!(Mode & UnitDiag))
- other.coeffRef(i) /= lhs.coeff(i,i);
- }
+ bool useRhsDirectly = Rhs::InnerStrideAtCompileTime==1 || rhs.innerStride()==1;
+ RhsScalar* actualRhs;
+ if(useRhsDirectly)
+ {
+ actualRhs = &rhs.coeffRef(0);
}
- }
-};
-
-// forward and backward substitution, column-major, rhs is a vector
-template<typename Lhs, typename Rhs, int Mode>
-struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,ColMajor,1>
-{
- typedef typename Lhs::Scalar LhsScalar;
- typedef typename Rhs::Scalar RhsScalar;
- typedef blas_traits<Lhs> LhsProductTraits;
- typedef typename LhsProductTraits::ExtractType ActualLhsType;
- typedef typename Lhs::Index Index;
- enum {
- IsLower = ((Mode&Lower)==Lower)
- };
-
- static void run(const Lhs& lhs, Rhs& other)
- {
- static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
- ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
-
- const Index size = lhs.cols();
- for(Index pi=IsLower ? 0 : size;
- IsLower ? pi<size : pi>0;
- IsLower ? pi+=PanelWidth : pi-=PanelWidth)
+ else
{
- Index actualPanelWidth = std::min(IsLower ? size - pi : pi, PanelWidth);
- Index startBlock = IsLower ? pi : pi-actualPanelWidth;
- Index endBlock = IsLower ? pi + actualPanelWidth : 0;
+ actualRhs = ei_aligned_stack_new(RhsScalar,rhs.size());
+ MappedRhs(actualRhs,rhs.size()) = rhs;
+ }
- for(Index k=0; k<actualPanelWidth; ++k)
- {
- Index i = IsLower ? pi+k : pi-k-1;
- if(!(Mode & UnitDiag))
- other.coeffRef(i) /= lhs.coeff(i,i);
+
+ triangular_solve_vector<LhsScalar, RhsScalar, typename Lhs::Index, Mode, LhsProductTraits::NeedToConjugate, StorageOrder>
+ ::run(actualLhs.cols(), actualLhs.data(), actualLhs.outerStride(), actualRhs);
- Index r = actualPanelWidth - k - 1; // remaining size
- Index s = IsLower ? i+1 : i-r;
- if (r>0)
- other.segment(s,r) -= other.coeffRef(i) * Block<Lhs,Dynamic,1>(lhs, s, i, r, 1);
- }
- Index r = IsLower ? size - endBlock : startBlock; // remaining size
- if (r > 0)
- {
- // let's directly call the low level product function because:
- // 1 - it is faster to compile
- // 2 - it is slighlty faster at runtime
- general_matrix_vector_product<Index,LhsScalar,ColMajor,LhsProductTraits::NeedToConjugate,RhsScalar,false>::run(
- r, actualPanelWidth,
- &(actualLhs.const_cast_derived().coeffRef(endBlock,startBlock)), actualLhs.outerStride(),
- &other.coeff(startBlock), other.innerStride(),
- &(other.coeffRef(endBlock, 0)), other.innerStride(), RhsScalar(-1));
- }
+ if(!useRhsDirectly)
+ {
+ rhs = MappedRhs(actualRhs, rhs.size());
+ ei_aligned_stack_delete(RhsScalar, actualRhs, rhs.size());
}
}
};
@@ -172,8 +110,6 @@ struct triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,Unrolling,StorageOrder
}
};
-template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder>
-struct triangular_solve_matrix;
// the rhs is a matrix
template<typename Lhs, typename Rhs, int Side, int Mode, int StorageOrder>
diff --git a/Eigen/src/Core/products/TriangularSolverVector.h b/Eigen/src/Core/products/TriangularSolverVector.h
new file mode 100644
index 000000000..fcf8bcae0
--- /dev/null
+++ b/Eigen/src/Core/products/TriangularSolverVector.h
@@ -0,0 +1,138 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
+//
+// Eigen is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License as published by the Free Software Foundation; either
+// version 3 of the License, or (at your option) any later version.
+//
+// Alternatively, you can redistribute it and/or
+// modify it under the terms of the GNU General Public License as
+// published by the Free Software Foundation; either version 2 of
+// the License, or (at your option) any later version.
+//
+// Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
+// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License and a copy of the GNU General Public License along with
+// Eigen. If not, see <http://www.gnu.org/licenses/>.
+
+#ifndef EIGEN_TRIANGULAR_SOLVER_VECTOR_H
+#define EIGEN_TRIANGULAR_SOLVER_VECTOR_H
+
+namespace internal {
+
+// Forward and backward substitution, row-major lhs, rhs is a vector.
+//
+// Solves op(lhs) * x = rhs in place: rhs is overwritten with the solution x,
+// where op() is the complex conjugate if Conjugate is true and the identity otherwise.
+//  - size:      number of rows (and columns) of the square triangular lhs
+//  - _lhs:      pointer to the first coefficient of the row-major lhs
+//  - lhsStride: outer stride of the lhs, as passed to OuterStride<>
+//  - rhs:       pointer to the rhs coefficients, accessed with a unit inner stride
+// Mode selects the lower (forward) or upper (backward) variant; if it contains
+// UnitDiag the division by the diagonal coefficients is skipped.
+template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate>
+struct triangular_solve_vector<LhsScalar, RhsScalar, Index, Mode, Conjugate, RowMajor>
+{
+  enum {
+    IsLower = ((Mode&Lower)==Lower)
+  };
+  // size is an Index (not an int) to match the other index quantities used below
+  // and to avoid a narrowing conversion when Index is wider than int.
+  static void run(Index size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
+  {
+    typedef Map<Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
+    const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride));
+    // cjLhs is either a plain reference to lhs or a conjugating view on it
+    typename internal::conditional<
+                          Conjugate,
+                          const CwiseUnaryOp<typename internal::scalar_conjugate_op<LhsScalar>,LhsMap>,
+                          const LhsMap&>
+                        ::type cjLhs(lhs);
+    static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
+    // process the unknowns by panels of at most PanelWidth entries,
+    // from the top (lower) or from the bottom (upper)
+    for(Index pi=IsLower ? 0 : size;
+        IsLower ? pi<size : pi>0;
+        IsLower ? pi+=PanelWidth : pi-=PanelWidth)
+    {
+      Index actualPanelWidth = std::min(IsLower ? size - pi : pi, PanelWidth);
+
+      Index r = IsLower ? pi : size - pi; // number of unknowns already solved
+      if (r > 0)
+      {
+        // subtract the contribution of the already solved unknowns from the current panel.
+        // let's directly call the low level product function because:
+        // 1 - it is faster to compile
+        // 2 - it is slightly faster at runtime
+        Index startRow = IsLower ? pi : pi-actualPanelWidth;
+        Index startCol = IsLower ? 0 : pi;
+
+        general_matrix_vector_product<Index,LhsScalar,RowMajor,Conjugate,RhsScalar,false>::run(
+          actualPanelWidth, r,
+          &(lhs.coeff(startRow,startCol)), lhsStride,
+          rhs + startCol, 1,
+          rhs + startRow, 1,
+          RhsScalar(-1));
+      }
+
+      // solve the small triangular system of the current panel
+      for(Index k=0; k<actualPanelWidth; ++k)
+      {
+        Index i = IsLower ? pi+k : pi-k-1;  // current row/unknown
+        Index s = IsLower ? pi : i+1;       // first in-panel unknown already solved
+        if (k>0)
+          rhs[i] -= (cjLhs.row(i).segment(s,k).transpose().cwiseProduct(Map<Matrix<RhsScalar,Dynamic,1> >(rhs+s,k))).sum();
+
+        // the diagonal has to be read through cjLhs so that it gets conjugated
+        // like the off-diagonal coefficients when Conjugate is true
+        if(!(Mode & UnitDiag))
+          rhs[i] /= cjLhs.coeff(i,i);
+      }
+    }
+  }
+};
+
+// Forward and backward substitution, column-major lhs, rhs is a vector.
+//
+// Solves op(lhs) * x = rhs in place: rhs is overwritten with the solution x,
+// where op() is the complex conjugate if Conjugate is true and the identity otherwise.
+//  - size:      number of rows (and columns) of the square triangular lhs
+//  - _lhs:      pointer to the first coefficient of the column-major lhs
+//  - lhsStride: outer stride of the lhs, as passed to OuterStride<>
+//  - rhs:       pointer to the rhs coefficients, accessed with a unit inner stride
+// Mode selects the lower (forward) or upper (backward) variant; if it contains
+// UnitDiag the division by the diagonal coefficients is skipped.
+template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate>
+struct triangular_solve_vector<LhsScalar, RhsScalar, Index, Mode, Conjugate, ColMajor>
+{
+  enum {
+    IsLower = ((Mode&Lower)==Lower)
+  };
+  // size is an Index (not an int) to match the other index quantities used below
+  // and to avoid a narrowing conversion when Index is wider than int.
+  static void run(Index size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
+  {
+    typedef Map<Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
+    const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride));
+    // cjLhs is either a plain reference to lhs or a conjugating view on it
+    typename internal::conditional<Conjugate,
+                                   const CwiseUnaryOp<typename internal::scalar_conjugate_op<LhsScalar>,LhsMap>,
+                                   const LhsMap&
+                                  >::type cjLhs(lhs);
+    static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
+
+    // process the unknowns by panels of at most PanelWidth entries,
+    // from the top (lower) or from the bottom (upper)
+    for(Index pi=IsLower ? 0 : size;
+        IsLower ? pi<size : pi>0;
+        IsLower ? pi+=PanelWidth : pi-=PanelWidth)
+    {
+      Index actualPanelWidth = std::min(IsLower ? size - pi : pi, PanelWidth);
+      Index startBlock = IsLower ? pi : pi-actualPanelWidth;
+      Index endBlock = IsLower ? pi + actualPanelWidth : 0;
+
+      // solve the small triangular system of the current panel, one column at a time
+      for(Index k=0; k<actualPanelWidth; ++k)
+      {
+        Index i = IsLower ? pi+k : pi-k-1;  // current column/unknown
+        if(!(Mode & UnitDiag))
+          rhs[i] /= cjLhs.coeff(i,i);
+
+        Index r = actualPanelWidth - k - 1; // remaining size within the panel
+        Index s = IsLower ? i+1 : i-r;
+        // propagate the freshly computed unknown to the rest of the panel
+        if (r>0)
+          Map<Matrix<RhsScalar,Dynamic,1> >(rhs+s,r) -= rhs[i] * cjLhs.col(i).segment(s,r);
+      }
+      Index r = IsLower ? size - endBlock : startBlock; // remaining size outside the panel
+      if (r > 0)
+      {
+        // propagate the unknowns of the current panel to the remaining part of the rhs.
+        // let's directly call the low level product function because:
+        // 1 - it is faster to compile
+        // 2 - it is slightly faster at runtime
+        general_matrix_vector_product<Index,LhsScalar,ColMajor,Conjugate,RhsScalar,false>::run(
+            r, actualPanelWidth,
+            &(lhs.coeff(endBlock,startBlock)), lhsStride,
+            rhs+startBlock, 1,
+            rhs+endBlock, 1, RhsScalar(-1));
+      }
+    }
+  }
+};
+
+} // end namespace internal
+
+#endif // EIGEN_TRIANGULAR_SOLVER_VECTOR_H