aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2009-07-10 11:30:46 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2009-07-10 11:30:46 +0200
commitb47dea8b7aeab10cf584f2d3275192d90d8df2ed (patch)
treef9647716407f4fe897d8d277e9f99cfda08a9c86
parent1a1b2e9f27db619303e7f212f9bf5c58a2dd988c (diff)
add a meta unroller for the triangular solver (only for vectors as rhs)
-rw-r--r--Eigen/src/Core/SolveTriangular.h84
-rw-r--r--test/triangular.cpp12
2 files changed, 72 insertions, 24 deletions
diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h
index 452d40a4c..3a65a8b27 100644
--- a/Eigen/src/Core/SolveTriangular.h
+++ b/Eigen/src/Core/SolveTriangular.h
@@ -26,28 +26,25 @@
#define EIGEN_SOLVETRIANGULAR_H
template<typename Lhs, typename Rhs,
- int Mode, // Upper/Lower | UnitDiag
- int UpLo = (Mode & LowerTriangularBit)
- ? LowerTriangular
- : (Mode & UpperTriangularBit)
- ? UpperTriangular
- : -1,
+ int Mode, // can be Upper/Lower | UnitDiag
+ int Unrolling = Rhs::IsVectorAtCompileTime && Rhs::SizeAtCompileTime <= 8 // FIXME
+ ? CompleteUnrolling : NoUnrolling,
int StorageOrder = int(Lhs::Flags) & RowMajorBit
>
struct ei_triangular_solver_selector;
// forward substitution, row-major
-template<typename Lhs, typename Rhs, int Mode, int UpLo>
-struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,RowMajor>
+template<typename Lhs, typename Rhs, int Mode>
+struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,RowMajor>
{
typedef typename Rhs::Scalar Scalar;
typedef ei_product_factor_traits<Lhs> LhsProductTraits;
typedef typename LhsProductTraits::ActualXprType ActualLhsType;
enum {
- IsLowerTriangular = (UpLo==LowerTriangular)
+ IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit)
};
static void run(const Lhs& lhs, Rhs& other)
- {//std::cerr << "row maj " << LhsProductTraits::NeedToConjugate << "\n";
+ {
static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH;
const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs);
@@ -90,12 +87,12 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,RowMajor>
};
// Implements the following configurations:
-// - inv(LowerTriangular, ColMajor) * Column vector
-// - inv(LowerTriangular,UnitDiag,ColMajor) * Column vector
-// - inv(UpperTriangular, ColMajor) * Column vector
-// - inv(UpperTriangular,UnitDiag,ColMajor) * Column vector
-template<typename Lhs, typename Rhs, int Mode, int UpLo>
-struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
+// - inv(LowerTriangular, ColMajor) * Column vectors
+// - inv(LowerTriangular,UnitDiag,ColMajor) * Column vectors
+// - inv(UpperTriangular, ColMajor) * Column vectors
+// - inv(UpperTriangular,UnitDiag,ColMajor) * Column vectors
+template<typename Lhs, typename Rhs, int Mode>
+struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,ColMajor>
{
typedef typename Rhs::Scalar Scalar;
typedef typename ei_packet_traits<Scalar>::type Packet;
@@ -103,11 +100,11 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
typedef typename LhsProductTraits::ActualXprType ActualLhsType;
enum {
PacketSize = ei_packet_traits<Scalar>::size,
- IsLowerTriangular = (UpLo==LowerTriangular)
+ IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit)
};
static void run(const Lhs& lhs, Rhs& other)
- {//std::cerr << "col maj " << LhsProductTraits::NeedToConjugate << "\n";
+ {
static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH;
const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs);
@@ -154,6 +151,49 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
}
};
+/***************************************************************************
+* meta-unrolling implementation
+***************************************************************************/
+
+template<typename Lhs, typename Rhs, int Mode, int Index, int Size,
+ bool Stop = Index==Size> // Stop selects the terminating specialization once Index reaches Size
+struct ei_triangular_solver_unroller;
+
+template<typename Lhs, typename Rhs, int Mode, int Index, int Size>
+struct ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,false> { // recursive step, used while Index != Size
+ enum {
+ IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit),
+ I = IsLowerTriangular ? Index : Size - Index - 1, // coefficient solved at this step: ascending for lower, descending for upper
+ S = IsLowerTriangular ? 0 : I+1 // start of the Index coefficients handled by the previous steps
+ };
+ static void run(const Lhs& lhs, Rhs& rhs)
+ {
+ if (Index>0) // the first step has no previously handled coefficients to subtract
+ rhs.coeffRef(I) -= ((lhs.row(I).template segment<Index>(S).transpose())
+ .cwise()*(rhs.template segment<Index>(S))).sum();
+
+ if(!(Mode & UnitDiagBit)) // no division when Mode carries UnitDiagBit
+ rhs.coeffRef(I) /= lhs.coeff(I,I);
+
+ ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index+1,Size>::run(lhs,rhs); // compile-time recursion to the next step
+ }
+};
+
+template<typename Lhs, typename Rhs, int Mode, int Index, int Size>
+struct ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,true> { // Stop==true (Index==Size): ends the recursion
+ static void run(const Lhs& lhs, Rhs& rhs) {} // intentionally does nothing
+};
+
+template<typename Lhs, typename Rhs, int Mode, int StorageOrder>
+struct ei_triangular_solver_selector<Lhs,Rhs,Mode,CompleteUnrolling,StorageOrder> { // unrolled path, for any StorageOrder
+ static void run(const Lhs& lhs, Rhs& rhs)
+ { ei_triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); } // starts at Index 0 with Size = Rhs::SizeAtCompileTime
+};
+
+/***************************************************************************
+* TriangularView methods
+***************************************************************************/
+
/** "in-place" version of MatrixBase::solveTriangular() where the result is written in \a other
*
* \nonstableyet
@@ -161,7 +201,7 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
* \warning The parameter is only marked 'const' to make the C++ compiler accept a temporary expression here.
* This function will const_cast it, so constness isn't honored here.
*
- * See MatrixBase:solveTriangular() for the details.
+ * See TriangularView::solve() for the details.
*/
template<typename MatrixType, unsigned int Mode>
template<typename RhsDerived>
@@ -198,8 +238,6 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>&
* can be done by marked(), and that is automatically the case with expressions such as those returned
* by extract().
*
- * \addexample SolveTriangular \label How to solve a triangular system (aka. how to multiply the inverse of a triangular matrix by another one)
- *
* Example: \include MatrixBase_marked.cpp
* Output: \verbinclude MatrixBase_marked.out
*
@@ -213,10 +251,10 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>&
*
* \b Tips: to perform a \em "right-inverse-multiply" you can simply transpose the operation, e.g.:
* \code
- * M * T^1 <=> T.transpose().solveTriangularInPlace(M.transpose());
+ * M * T^-1 <=> T.transpose().solveInPlace(M.transpose());
* \endcode
*
- * \sa solveTriangularInPlace()
+ * \sa TriangularView::solveInPlace()
*/
template<typename Derived, unsigned int Mode>
template<typename RhsDerived>
diff --git a/test/triangular.cpp b/test/triangular.cpp
index 0c03e987e..7c680a8ed 100644
--- a/test/triangular.cpp
+++ b/test/triangular.cpp
@@ -86,7 +86,17 @@ template<typename MatrixType> void triangular(const MatrixType& m)
while (ei_abs2(m1(i,i))<1e-3) m1(i,i) = ei_random<Scalar>();
Transpose<MatrixType> trm4(m4);
- // test back and forward subsitution
+ // test back and forward substitution with a vector as the rhs
+ m3 = m1.template triangularView<Eigen::UpperTriangular>();
+ VERIFY(v2.isApprox(m3.adjoint() * (m1.adjoint().template triangularView<Eigen::LowerTriangular>().solve(v2)), largerEps));
+ m3 = m1.template triangularView<Eigen::LowerTriangular>();
+ VERIFY(v2.isApprox(m3.transpose() * (m1.transpose().template triangularView<Eigen::UpperTriangular>().solve(v2)), largerEps));
+ m3 = m1.template triangularView<Eigen::UpperTriangular>();
+ VERIFY(v2.isApprox(m3 * (m1.template triangularView<Eigen::UpperTriangular>().solve(v2)), largerEps));
+ m3 = m1.template triangularView<Eigen::LowerTriangular>();
+ VERIFY(v2.isApprox(m3.conjugate() * (m1.conjugate().template triangularView<Eigen::LowerTriangular>().solve(v2)), largerEps));
+
+ // test back and forward substitution with a matrix as the rhs
m3 = m1.template triangularView<Eigen::UpperTriangular>();
VERIFY(m2.isApprox(m3.adjoint() * (m1.adjoint().template triangularView<Eigen::LowerTriangular>().solve(m2)), largerEps));
m3 = m1.template triangularView<Eigen::LowerTriangular>();