aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2009-07-31 13:18:19 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2009-07-31 13:18:19 +0200
commita156f5a8696a22a2738ab54bcc77a7565d5e9019 (patch)
tree8c4c2bf99c6ce1a90d5a87ae914016679423f07e
parentff20a2ba9423161efd2796832b896f9a4a1a3006 (diff)
faster trsm kernel and fix a couple of issues
-rw-r--r--Eigen/src/Core/SolveTriangular.h4
-rw-r--r--Eigen/src/Core/products/TriangularSolverMatrix.h54
2 files changed, 22 insertions, 36 deletions
diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h
index f60ef1c03..9b67dd580 100644
--- a/Eigen/src/Core/SolveTriangular.h
+++ b/Eigen/src/Core/SolveTriangular.h
@@ -49,7 +49,7 @@ struct ei_triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,RowMajor
{
static const int PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
-
+
const int size = lhs.cols();
for(int pi=IsLowerTriangular ? 0 : size;
IsLowerTriangular ? pi<size : pi>0;
@@ -224,7 +224,7 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>&
ei_assert(!(Mode & ZeroDiagBit));
ei_assert(Mode & (UpperTriangularBit|LowerTriangularBit));
- enum { copy = ei_traits<RhsDerived>::Flags & RowMajorBit };
+ enum { copy = ei_traits<RhsDerived>::Flags & RowMajorBit && RhsDerived::IsVectorAtCompileTime };
typedef typename ei_meta_if<copy,
typename ei_plain_matrix_type_column_major<RhsDerived>::type, RhsDerived&>::ret RhsCopy;
RhsCopy rhsCopy(rhs);
diff --git a/Eigen/src/Core/products/TriangularSolverMatrix.h b/Eigen/src/Core/products/TriangularSolverMatrix.h
index 7842f8703..e49fac956 100644
--- a/Eigen/src/Core/products/TriangularSolverMatrix.h
+++ b/Eigen/src/Core/products/TriangularSolverMatrix.h
@@ -25,7 +25,7 @@
#ifndef EIGEN_TRIANGULAR_SOLVER_MATRIX_H
#define EIGEN_TRIANGULAR_SOLVER_MATRIX_H
-// if the rhs is row major, we have to evaluate it in a temporary colmajor matrix
+// if the rhs is row major, let's transpose the product
template <typename Scalar, int Side, int Mode, bool Conjugate, int TriStorageOrder>
struct ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,RowMajor>
{
@@ -34,22 +34,16 @@ struct ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,Row
const Scalar* tri, int triStride,
Scalar* _other, int otherStride)
{
-
ei_triangular_solve_matrix<
Scalar, Side==OnTheLeft?OnTheRight:OnTheLeft,
- (Mode&UnitDiagBit) | (Mode&UpperTriangular) ? LowerTriangular : UpperTriangular,
- !Conjugate, TriStorageOrder, ColMajor>
+ (Mode&UnitDiagBit) | ((Mode&UpperTriangular) ? LowerTriangular : UpperTriangular),
+ NumTraits<Scalar>::IsComplex && Conjugate,
+ TriStorageOrder==RowMajor ? ColMajor : RowMajor, ColMajor>
::run(size, cols, tri, triStride, _other, otherStride);
-
-// Map<Matrix<Scalar,Dynamic,Dynamic> > other(_other, otherStride, cols);
-// Matrix<Scalar,Dynamic,Dynamic> aux = other.block(0,0,size,cols);
-// ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,ColMajor>
-// ::run(size, cols, tri, triStride, aux.data(), aux.stride());
-// other.block(0,0,size,cols) = aux;
}
};
-/* Optimized triangular solver with multiple right hand side (_TRSM)
+/* Optimized triangular solver with multiple right hand side and the triangular matrix on the left
*/
template <typename Scalar, int Mode, bool Conjugate, int TriStorageOrder>
struct ei_triangular_solve_matrix<Scalar,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor>
@@ -190,11 +184,8 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
Scalar* _other, int otherStride)
{
int rows = otherSize;
-// ei_const_blas_data_mapper<Scalar, TriStorageOrder> rhs(_tri,triStride);
-// ei_blas_data_mapper<Scalar, ColMajor> lhs(_other,otherStride);
-
- Map<Matrix<Scalar,Dynamic,Dynamic,TriStorageOrder> > rhs(_tri,size,size);
- Map<Matrix<Scalar,Dynamic,Dynamic,ColMajor> > lhs(_other,rows,size);
+ ei_const_blas_data_mapper<Scalar, TriStorageOrder> rhs(_tri,triStride);
+ ei_blas_data_mapper<Scalar, ColMajor> lhs(_other,otherStride);
typedef ei_product_blocking_traits<Scalar> Blocking;
enum {
@@ -203,8 +194,8 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
IsLowerTriangular = (Mode&LowerTriangular) == LowerTriangular
};
- int kc = std::min<int>(/*Blocking::Max_kc/4*/32,size); // cache block size along the K direction
- int mc = std::min<int>(/*Blocking::Max_mc*/32,size); // cache block size along the M direction
+ int kc = std::min<int>(Blocking::Max_kc/4,size); // cache block size along the K direction
+ int mc = std::min<int>(Blocking::Max_mc,size); // cache block size along the M direction
Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
Scalar* blockB = ei_aligned_stack_new(Scalar, kc*size*Blocking::PacketSize);
@@ -214,7 +205,6 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder> pack_rhs;
ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder,true> pack_rhs_panel;
ei_gemm_pack_lhs<Scalar, Blocking::mr, ColMajor, false, true> pack_lhs_panel;
- ei_gemm_pack_lhs<Scalar, Blocking::mr, ColMajor, false> pack_lhs;
for(int k2=IsLowerTriangular ? size : 0;
IsLowerTriangular ? k2>0 : k2<size;
@@ -224,7 +214,7 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
int actual_k2 = IsLowerTriangular ? k2-actual_kc : k2 ;
int startPanel = IsLowerTriangular ? 0 : k2+actual_kc;
- int rs = IsLowerTriangular ? actual_k2 : size - actual_k2 - actual_kc;
+ int rs = IsLowerTriangular ? actual_k2 : size - actual_k2 - actual_kc;
Scalar* geb = blockB+actual_kc*actual_kc*Blocking::PacketSize;
if (rs>0) pack_rhs(geb, &rhs(actual_k2,startPanel), triStride, -1, actual_kc, rs);
@@ -239,8 +229,6 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
int panelOffset = IsLowerTriangular ? j2+actualPanelWidth : 0;
int panelLength = IsLowerTriangular ? actual_kc-j2-actualPanelWidth : j2;
-// std::cerr << "$ " << k2 << " " << j2 << " " << actual_j2 << " " << panelOffset << " " << panelLength << "\n";
-
if (panelLength>0)
pack_rhs_panel(blockB+j2*actual_kc*Blocking::PacketSize,
&rhs(actual_k2+panelOffset, actual_j2), triStride, -1,
@@ -269,7 +257,6 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
int panelLength = IsLowerTriangular ? actual_kc - j2 - actualPanelWidth : j2;
// GEBP
- //if (lengthTarget>0)
if(panelLength>0)
{
gebp_kernel(&lhs(i2,absolute_j2), otherStride,
@@ -284,18 +271,17 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
{
int j = IsLowerTriangular ? absolute_j2+actualPanelWidth-k-1 : absolute_j2+k;
- Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(rhs(j,j));
- for (int i=0; i<actual_mc; ++i)
+ Scalar* r = &lhs(i2,j);
+ for (int k3=0; k3<k; ++k3)
{
- int absolute_i = i2+i;
- Scalar b = 0;
- for (int k3=0; k3<k; ++k3)
- if(IsLowerTriangular)
- b += lhs(absolute_i,j+1+k3) * conj(rhs(j+1+k3,j));
- else
- b += lhs(absolute_i,absolute_j2+k3) * conj(rhs(absolute_j2+k3,j));
- lhs(absolute_i,j) = (lhs(absolute_i,j) - b)*a;
+ Scalar b = conj(rhs(IsLowerTriangular ? j+1+k3 : absolute_j2+k3,j));
+ Scalar* a = &lhs(i2,IsLowerTriangular ? j+1+k3 : absolute_j2+k3);
+ for (int i=0; i<actual_mc; ++i)
+ r[i] -= a[i] * b;
}
+ Scalar b = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(rhs(j,j));
+ for (int i=0; i<actual_mc; ++i)
+ r[i] *= b;
}
// pack the just computed part of lhs to A
@@ -304,7 +290,7 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
actual_kc, j2);
}
}
-
+
if (rs>0)
gebp_kernel(_other+i2+startPanel*otherStride, otherStride, blockA, geb,
actual_mc, actual_kc, rs);