aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/products/TriangularSolverMatrix.h
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2009-07-30 16:03:06 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2009-07-30 16:03:06 +0200
commitff20a2ba9423161efd2796832b896f9a4a1a3006 (patch)
tree54b8f8d8c442d6c0d150db0770989e07434dd816 /Eigen/src/Core/products/TriangularSolverMatrix.h
parent62d9b9b7b51fa0249baf59db91bdfd9af191cdb3 (diff)
add explicit "on the right" triangular solving,
=> no temporary when the rhs/unknows is row major
Diffstat (limited to 'Eigen/src/Core/products/TriangularSolverMatrix.h')
-rw-r--r--Eigen/src/Core/products/TriangularSolverMatrix.h211
1 files changed, 178 insertions, 33 deletions
diff --git a/Eigen/src/Core/products/TriangularSolverMatrix.h b/Eigen/src/Core/products/TriangularSolverMatrix.h
index e28aba747..7842f8703 100644
--- a/Eigen/src/Core/products/TriangularSolverMatrix.h
+++ b/Eigen/src/Core/products/TriangularSolverMatrix.h
@@ -26,34 +26,42 @@
#define EIGEN_TRIANGULAR_SOLVER_MATRIX_H
// if the rhs is row major, we have to evaluate it in a temporary colmajor matrix
-template <typename Scalar, int LhsStorageOrder, bool ConjugateLhs, int Mode>
-struct ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,RowMajor,Mode>
+template <typename Scalar, int Side, int Mode, bool Conjugate, int TriStorageOrder>
+struct ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,RowMajor>
{
static EIGEN_DONT_INLINE void run(
int size, int cols,
- const Scalar* lhs, int lhsStride,
- Scalar* _rhs, int rhsStride)
+ const Scalar* tri, int triStride,
+ Scalar* _other, int otherStride)
{
- Map<Matrix<Scalar,Dynamic,Dynamic> > rhs(_rhs, rhsStride, cols);
- Matrix<Scalar,Dynamic,Dynamic> aux = rhs.block(0,0,size,cols);
- ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,Mode>
- ::run(size, cols, lhs, lhsStride, aux.data(), aux.stride());
- rhs.block(0,0,size,cols) = aux;
+
+ ei_triangular_solve_matrix<
+ Scalar, Side==OnTheLeft?OnTheRight:OnTheLeft,
+ (Mode&UnitDiagBit) | (Mode&UpperTriangular) ? LowerTriangular : UpperTriangular,
+ !Conjugate, TriStorageOrder, ColMajor>
+ ::run(size, cols, tri, triStride, _other, otherStride);
+
+// Map<Matrix<Scalar,Dynamic,Dynamic> > other(_other, otherStride, cols);
+// Matrix<Scalar,Dynamic,Dynamic> aux = other.block(0,0,size,cols);
+// ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,ColMajor>
+// ::run(size, cols, tri, triStride, aux.data(), aux.stride());
+// other.block(0,0,size,cols) = aux;
}
};
/* Optimized triangular solver with multiple right hand side (_TRSM)
*/
-template <typename Scalar, int LhsStorageOrder, bool ConjugateLhs, int Mode>
-struct ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,Mode>
+template <typename Scalar, int Mode, bool Conjugate, int TriStorageOrder>
+struct ei_triangular_solve_matrix<Scalar,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor>
{
static EIGEN_DONT_INLINE void run(
- int size, int cols,
- const Scalar* _lhs, int lhsStride,
- Scalar* _rhs, int rhsStride)
+ int size, int otherSize,
+ const Scalar* _tri, int triStride,
+ Scalar* _other, int otherStride)
{
- ei_const_blas_data_mapper<Scalar, LhsStorageOrder> lhs(_lhs,lhsStride);
- ei_blas_data_mapper <Scalar, ColMajor> rhs(_rhs,rhsStride);
+ int cols = otherSize;
+ ei_const_blas_data_mapper<Scalar, TriStorageOrder> tri(_tri,triStride);
+ ei_blas_data_mapper<Scalar, ColMajor> other(_other,otherStride);
typedef ei_product_blocking_traits<Scalar> Blocking;
enum {
@@ -67,9 +75,9 @@ struct ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,M
Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
Scalar* blockB = ei_aligned_stack_new(Scalar, kc*cols*Blocking::PacketSize);
- ei_conj_if<ConjugateLhs> conj;
- ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<ConjugateLhs,false> > gebp_kernel;
- ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder> pack_lhs;
+ ei_conj_if<Conjugate> conj;
+ ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<Conjugate,false> > gebp_kernel;
+ ei_gemm_pack_lhs<Scalar,Blocking::mr,TriStorageOrder> pack_lhs;
for(int k2=IsLowerTriangular ? 0 : size;
IsLowerTriangular ? k2<size : k2>0;
@@ -103,25 +111,25 @@ struct ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,M
int s = IsLowerTriangular ? k2+k1 : i+1;
int rs = actualPanelWidth - k - 1; // remaining size
- Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(lhs(i,i));
+ Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(tri(i,i));
for (int j=0; j<cols; ++j)
{
- if (LhsStorageOrder==RowMajor)
+ if (TriStorageOrder==RowMajor)
{
Scalar b = 0;
- const Scalar* l = &lhs(i,s);
- Scalar* r = &rhs(s,j);
+ const Scalar* l = &tri(i,s);
+ Scalar* r = &other(s,j);
for (int i3=0; i3<k; ++i3)
b += conj(l[i3]) * r[i3];
- rhs(i,j) = (rhs(i,j) - b)*a;
+ other(i,j) = (other(i,j) - b)*a;
}
else
{
int s = IsLowerTriangular ? i+1 : i-rs;
- Scalar b = (rhs(i,j) *= a);
- Scalar* r = &rhs(s,j);
- const Scalar* l = &lhs(s,i);
+ Scalar b = (other(i,j) *= a);
+ Scalar* r = &other(s,j);
+ const Scalar* l = &tri(s,i);
for (int i3=0;i3<rs;++i3)
r[i3] -= b * conj(l[i3]);
}
@@ -132,18 +140,18 @@ struct ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,M
int startBlock = IsLowerTriangular ? k2+k1 : k2-k1-actualPanelWidth;
int blockBOffset = IsLowerTriangular ? k1 : lengthTarget;
- // update the respective rows of B from rhs
+ // update the respective rows of B from other
ei_gemm_pack_rhs<Scalar, Blocking::nr, ColMajor, true>()
- (blockB, _rhs+startBlock, rhsStride, -1, actualPanelWidth, cols, actual_kc, blockBOffset);
+ (blockB, _other+startBlock, otherStride, -1, actualPanelWidth, cols, actual_kc, blockBOffset);
// GEBP
if (lengthTarget>0)
{
int startTarget = IsLowerTriangular ? k2+k1+actualPanelWidth : k2-actual_kc;
- pack_lhs(blockA, &lhs(startTarget,startBlock), lhsStride, actualPanelWidth, lengthTarget);
+ pack_lhs(blockA, &tri(startTarget,startBlock), triStride, actualPanelWidth, lengthTarget);
- gebp_kernel(_rhs+startTarget, rhsStride, blockA, blockB, lengthTarget, actualPanelWidth, cols,
+ gebp_kernel(_other+startTarget, otherStride, blockA, blockB, lengthTarget, actualPanelWidth, cols,
actualPanelWidth, actual_kc, 0, blockBOffset*Blocking::PacketSize);
}
}
@@ -158,9 +166,9 @@ struct ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,M
const int actual_mc = std::min(mc,end-i2);
if (actual_mc>0)
{
- pack_lhs(blockA, &lhs(i2, IsLowerTriangular ? k2 : k2-kc), lhsStride, actual_kc, actual_mc);
+ pack_lhs(blockA, &tri(i2, IsLowerTriangular ? k2 : k2-kc), triStride, actual_kc, actual_mc);
- gebp_kernel(_rhs+i2, rhsStride, blockA, blockB, actual_mc, actual_kc, cols);
+ gebp_kernel(_other+i2, otherStride, blockA, blockB, actual_mc, actual_kc, cols);
}
}
}
@@ -171,4 +179,141 @@ struct ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,M
}
};
+/* Optimized triangular solver with multiple left hand sides and the trinagular matrix on the right
+ */
+template <typename Scalar, int Mode, bool Conjugate, int TriStorageOrder>
+struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor>
+{
+ static EIGEN_DONT_INLINE void run(
+ int size, int otherSize,
+ const Scalar* _tri, int triStride,
+ Scalar* _other, int otherStride)
+ {
+ int rows = otherSize;
+// ei_const_blas_data_mapper<Scalar, TriStorageOrder> rhs(_tri,triStride);
+// ei_blas_data_mapper<Scalar, ColMajor> lhs(_other,otherStride);
+
+ Map<Matrix<Scalar,Dynamic,Dynamic,TriStorageOrder> > rhs(_tri,size,size);
+ Map<Matrix<Scalar,Dynamic,Dynamic,ColMajor> > lhs(_other,rows,size);
+
+ typedef ei_product_blocking_traits<Scalar> Blocking;
+ enum {
+ RhsStorageOrder = TriStorageOrder,
+ SmallPanelWidth = EIGEN_ENUM_MAX(Blocking::mr,Blocking::nr),
+ IsLowerTriangular = (Mode&LowerTriangular) == LowerTriangular
+ };
+
+ int kc = std::min<int>(/*Blocking::Max_kc/4*/32,size); // cache block size along the K direction
+ int mc = std::min<int>(/*Blocking::Max_mc*/32,size); // cache block size along the M direction
+
+ Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
+ Scalar* blockB = ei_aligned_stack_new(Scalar, kc*size*Blocking::PacketSize);
+
+ ei_conj_if<Conjugate> conj;
+ ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<false,Conjugate> > gebp_kernel;
+ ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder> pack_rhs;
+ ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder,true> pack_rhs_panel;
+ ei_gemm_pack_lhs<Scalar, Blocking::mr, ColMajor, false, true> pack_lhs_panel;
+ ei_gemm_pack_lhs<Scalar, Blocking::mr, ColMajor, false> pack_lhs;
+
+ for(int k2=IsLowerTriangular ? size : 0;
+ IsLowerTriangular ? k2>0 : k2<size;
+ IsLowerTriangular ? k2-=kc : k2+=kc)
+ {
+ const int actual_kc = std::min(IsLowerTriangular ? k2 : size-k2, kc);
+ int actual_k2 = IsLowerTriangular ? k2-actual_kc : k2 ;
+
+ int startPanel = IsLowerTriangular ? 0 : k2+actual_kc;
+ int rs = IsLowerTriangular ? actual_k2 : size - actual_k2 - actual_kc;
+ Scalar* geb = blockB+actual_kc*actual_kc*Blocking::PacketSize;
+
+ if (rs>0) pack_rhs(geb, &rhs(actual_k2,startPanel), triStride, -1, actual_kc, rs);
+
+ // triangular packing (we only pack the panels off the diagonal,
+ // neglecting the blocks overlapping the diagonal
+ {
+ for (int j2=0; j2<actual_kc; j2+=SmallPanelWidth)
+ {
+ int actualPanelWidth = std::min<int>(actual_kc-j2, SmallPanelWidth);
+ int actual_j2 = actual_k2 + j2;
+ int panelOffset = IsLowerTriangular ? j2+actualPanelWidth : 0;
+ int panelLength = IsLowerTriangular ? actual_kc-j2-actualPanelWidth : j2;
+
+// std::cerr << "$ " << k2 << " " << j2 << " " << actual_j2 << " " << panelOffset << " " << panelLength << "\n";
+
+ if (panelLength>0)
+ pack_rhs_panel(blockB+j2*actual_kc*Blocking::PacketSize,
+ &rhs(actual_k2+panelOffset, actual_j2), triStride, -1,
+ panelLength, actualPanelWidth,
+ actual_kc, panelOffset);
+ }
+ }
+
+ for(int i2=0; i2<rows; i2+=mc)
+ {
+ const int actual_mc = std::min(mc,rows-i2);
+
+ // triangular solver kernel
+ {
+ // for each small block of the diagonal (=> vertical panels of rhs)
+ for (int j2 = IsLowerTriangular
+ ? (actual_kc - ((actual_kc%SmallPanelWidth) ? (actual_kc%SmallPanelWidth)
+ : SmallPanelWidth))
+ : 0;
+ IsLowerTriangular ? j2>=0 : j2<actual_kc;
+ IsLowerTriangular ? j2-=SmallPanelWidth : j2+=SmallPanelWidth)
+ {
+ int actualPanelWidth = std::min<int>(actual_kc-j2, SmallPanelWidth);
+ int absolute_j2 = actual_k2 + j2;
+ int panelOffset = IsLowerTriangular ? j2+actualPanelWidth : 0;
+ int panelLength = IsLowerTriangular ? actual_kc - j2 - actualPanelWidth : j2;
+
+ // GEBP
+ //if (lengthTarget>0)
+ if(panelLength>0)
+ {
+ gebp_kernel(&lhs(i2,absolute_j2), otherStride,
+ blockA, blockB+j2*actual_kc*Blocking::PacketSize,
+ actual_mc, panelLength, actualPanelWidth,
+ actual_kc, actual_kc, // strides
+ panelOffset, panelOffset*Blocking::PacketSize); // offsets
+ }
+
+ // unblocked triangular solve
+ for (int k=0; k<actualPanelWidth; ++k)
+ {
+ int j = IsLowerTriangular ? absolute_j2+actualPanelWidth-k-1 : absolute_j2+k;
+
+ Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(rhs(j,j));
+ for (int i=0; i<actual_mc; ++i)
+ {
+ int absolute_i = i2+i;
+ Scalar b = 0;
+ for (int k3=0; k3<k; ++k3)
+ if(IsLowerTriangular)
+ b += lhs(absolute_i,j+1+k3) * conj(rhs(j+1+k3,j));
+ else
+ b += lhs(absolute_i,absolute_j2+k3) * conj(rhs(absolute_j2+k3,j));
+ lhs(absolute_i,j) = (lhs(absolute_i,j) - b)*a;
+ }
+ }
+
+ // pack the just computed part of lhs to A
+ pack_lhs_panel(blockA, _other+absolute_j2*otherStride+i2, otherStride,
+ actualPanelWidth, actual_mc,
+ actual_kc, j2);
+ }
+ }
+
+ if (rs>0)
+ gebp_kernel(_other+i2+startPanel*otherStride, otherStride, blockA, geb,
+ actual_mc, actual_kc, rs);
+ }
+ }
+
+ ei_aligned_stack_delete(Scalar, blockA, kc*mc);
+ ei_aligned_stack_delete(Scalar, blockB, kc*size*Blocking::PacketSize);
+ }
+};
+
#endif // EIGEN_TRIANGULAR_SOLVER_MATRIX_H