aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2009-07-26 13:01:37 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2009-07-26 13:01:37 +0200
commitf3fde74695eff236fe24b05ffb053d3890346420 (patch)
treee67b2323c788de9cb55686499c8d08c8ca885a1b
parent282e18da4915943b0dc2ed0140cfd037aaa02a70 (diff)
finalize trsm: works in all situations, and it is now used by solve() and solveInPlace()
-rw-r--r--Eigen/src/Core/SolveTriangular.h31
-rw-r--r--Eigen/src/Core/products/TriangularSolverMatrix.h111
-rw-r--r--test/CMakeLists.txt1
-rw-r--r--test/product_trsm.cpp95
4 files changed, 185 insertions, 53 deletions
diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h
index cb162ca91..d0656eacb 100644
--- a/Eigen/src/Core/SolveTriangular.h
+++ b/Eigen/src/Core/SolveTriangular.h
@@ -29,13 +29,14 @@ template<typename Lhs, typename Rhs,
int Mode, // can be Upper/Lower | UnitDiag
int Unrolling = Rhs::IsVectorAtCompileTime && Rhs::SizeAtCompileTime <= 8 // FIXME
? CompleteUnrolling : NoUnrolling,
- int StorageOrder = int(Lhs::Flags) & RowMajorBit
+ int StorageOrder = int(Lhs::Flags) & RowMajorBit,
+ int RhsCols = Rhs::ColsAtCompileTime
>
struct ei_triangular_solver_selector;
-// forward and backward substitution, row-major
+// forward and backward substitution, row-major, rhs is a vector
template<typename Lhs, typename Rhs, int Mode>
-struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,RowMajor>
+struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,RowMajor,1>
{
typedef typename Rhs::Scalar Scalar;
typedef ei_blas_traits<Lhs> LhsProductTraits;
@@ -89,9 +90,9 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,RowMajor>
}
};
-// forward and backward substitution, column-major
+// forward and backward substitution, column-major, rhs is a vector
template<typename Lhs, typename Rhs, int Mode>
-struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,ColMajor>
+struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,ColMajor,1>
{
typedef typename Rhs::Scalar Scalar;
typedef typename ei_packet_traits<Scalar>::type Packet;
@@ -150,6 +151,24 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,ColMajor>
}
};
+template <typename Scalar, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder, int Mode>
+struct ei_triangular_solve_matrix;
+
+// the rhs is a matrix
+template<typename Lhs, typename Rhs, int Mode, int StorageOrder, int RhsCols>
+struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,StorageOrder,RhsCols>
+{
+ typedef typename Rhs::Scalar Scalar;
+ typedef ei_blas_traits<Lhs> LhsProductTraits;
+ typedef typename LhsProductTraits::ActualXprType ActualLhsType;
+ static void run(const Lhs& lhs, Rhs& rhs)
+ {
+ const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs);
+ ei_triangular_solve_matrix<Scalar,StorageOrder,LhsProductTraits::NeedToConjugate,Rhs::Flags&RowMajorBit,Mode>
+ ::run(lhs.rows(), rhs.cols(), &actualLhs.coeff(0,0), actualLhs.stride(), &rhs.coeffRef(0,0), rhs.stride());
+ }
+};
+
/***************************************************************************
* meta-unrolling implementation
***************************************************************************/
@@ -184,7 +203,7 @@ struct ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,true> {
};
template<typename Lhs, typename Rhs, int Mode, int StorageOrder>
-struct ei_triangular_solver_selector<Lhs,Rhs,Mode,CompleteUnrolling,StorageOrder> {
+struct ei_triangular_solver_selector<Lhs,Rhs,Mode,CompleteUnrolling,StorageOrder,1> {
static void run(const Lhs& lhs, Rhs& rhs)
{ ei_triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); }
};
diff --git a/Eigen/src/Core/products/TriangularSolverMatrix.h b/Eigen/src/Core/products/TriangularSolverMatrix.h
index eeb445f0b..550076f68 100644
--- a/Eigen/src/Core/products/TriangularSolverMatrix.h
+++ b/Eigen/src/Core/products/TriangularSolverMatrix.h
@@ -26,63 +26,37 @@
#define EIGEN_TRIANGULAR_SOLVER_MATRIX_H
template<typename Scalar, int nr>
-struct ei_gemm_pack_rhs_panel
+struct ei_gemm_pack_rhs_panel;
+
+// if the rhs is row major, we have to evaluate it in a temporary colmajor matrix
+template <typename Scalar, int LhsStorageOrder, bool ConjugateLhs, int Mode>
+struct ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,RowMajor,Mode>
{
- enum { PacketSize = ei_packet_traits<Scalar>::size };
- void operator()(Scalar* blockB, const Scalar* rhs, int rhsStride, Scalar alpha, int depth, int cols, int stride, int offset)
+ static EIGEN_DONT_INLINE void run(
+ int size, int cols,
+ const Scalar* lhs, int lhsStride,
+ Scalar* _rhs, int rhsStride)
{
- int packet_cols = (cols/nr) * nr;
- int count = 0;
- for(int j2=0; j2<packet_cols; j2+=nr)
- {
- // skip what we have before
- count += PacketSize * nr * offset;
- const Scalar* b0 = &rhs[(j2+0)*rhsStride];
- const Scalar* b1 = &rhs[(j2+1)*rhsStride];
- const Scalar* b2 = &rhs[(j2+2)*rhsStride];
- const Scalar* b3 = &rhs[(j2+3)*rhsStride];
- for(int k=0; k<depth; k++)
- {
- ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*b0[k]));
- ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*b1[k]));
- if(nr==4) ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*b2[k]));
- if(nr==4) ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*b3[k]));
- count += nr*PacketSize;
- }
- // skip what we have after
- count += PacketSize * nr * (stride-offset-depth);
- }
- // copy the remaining columns one at a time (nr==1)
- for(int j2=packet_cols; j2<cols; ++j2)
- {
- count += PacketSize * offset;
- const Scalar* b0 = &rhs[(j2+0)*rhsStride];
- for(int k=0; k<depth; k++)
- {
- ei_pstore(&blockB[count], ei_pset1(alpha*b0[k]));
- count += PacketSize;
- }
- count += PacketSize * (stride-offset-depth);
- }
+ Map<Matrix<Scalar,Dynamic,Dynamic> > rhs(_rhs, rhsStride, cols);
+ Matrix<Scalar,Dynamic,Dynamic> aux = rhs.block(0,0,size,cols);
+ ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,Mode>
+ ::run(size, cols, lhs, lhsStride, aux.data(), aux.stride());
+ rhs.block(0,0,size,cols) = aux;
}
};
/* Optimized triangular solver with multiple right hand side (_TRSM)
*/
-template <typename Scalar,
- int LhsStorageOrder,
- int RhsStorageOrder,
- int Mode>
-struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder>
+template <typename Scalar, int LhsStorageOrder, bool ConjugateLhs, int Mode>
+struct ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,Mode>
{
-
static EIGEN_DONT_INLINE void run(
int size, int cols,
const Scalar* _lhs, int lhsStride,
Scalar* _rhs, int rhsStride)
{
ei_const_blas_data_mapper<Scalar, LhsStorageOrder> lhs(_lhs,lhsStride);
- ei_blas_data_mapper <Scalar, RhsStorageOrder> rhs(_rhs,rhsStride);
+ ei_blas_data_mapper <Scalar, ColMajor> rhs(_rhs,rhsStride);
typedef ei_product_blocking_traits<Scalar> Blocking;
enum {
@@ -96,7 +70,8 @@ struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder>
Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
Scalar* blockB = ei_aligned_stack_new(Scalar, kc*cols*Blocking::PacketSize);
- ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<false,false> > gebp_kernel;
+ ei_conj_if<ConjugateLhs> conj;
+ ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<ConjugateLhs,false> > gebp_kernel;
ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder> pack_lhs;
for(int k2=IsLowerTriangular ? 0 : size;
@@ -131,7 +106,7 @@ struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder>
int s = IsLowerTriangular ? k2+k1 : i+1;
int rs = actualPanelWidth - k - 1; // remaining size
- Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/lhs(i,i);
+ Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(lhs(i,i));
for (int j=0; j<cols; ++j)
{
if (LhsStorageOrder==RowMajor)
@@ -140,7 +115,7 @@ struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder>
const Scalar* l = &lhs(i,s);
Scalar* r = &rhs(s,j);
for (int i3=0; i3<k; ++i3)
- b += l[i3] * r[i3];
+ b += conj(l[i3]) * r[i3];
rhs(i,j) = (rhs(i,j) - b)*a;
}
@@ -151,7 +126,7 @@ struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder>
Scalar* r = &rhs(s,j);
const Scalar* l = &lhs(s,i);
for (int i3=0;i3<rs;++i3)
- r[i3] -= b * l[i3];
+ r[i3] -= b * conj(l[i3]);
}
}
}
@@ -199,4 +174,46 @@ struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder>
}
};
+template<typename Scalar, int nr>
+struct ei_gemm_pack_rhs_panel
+{
+ enum { PacketSize = ei_packet_traits<Scalar>::size };
+ void operator()(Scalar* blockB, const Scalar* rhs, int rhsStride, Scalar alpha, int depth, int cols, int stride, int offset)
+ {
+ int packet_cols = (cols/nr) * nr;
+ int count = 0;
+ for(int j2=0; j2<packet_cols; j2+=nr)
+ {
+ // skip what we have before
+ count += PacketSize * nr * offset;
+ const Scalar* b0 = &rhs[(j2+0)*rhsStride];
+ const Scalar* b1 = &rhs[(j2+1)*rhsStride];
+ const Scalar* b2 = &rhs[(j2+2)*rhsStride];
+ const Scalar* b3 = &rhs[(j2+3)*rhsStride];
+ for(int k=0; k<depth; k++)
+ {
+ ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*b0[k]));
+ ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*b1[k]));
+ if(nr==4) ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*b2[k]));
+ if(nr==4) ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*b3[k]));
+ count += nr*PacketSize;
+ }
+ // skip what we have after
+ count += PacketSize * nr * (stride-offset-depth);
+ }
+ // copy the remaining columns one at a time (nr==1)
+ for(int j2=packet_cols; j2<cols; ++j2)
+ {
+ count += PacketSize * offset;
+ const Scalar* b0 = &rhs[(j2+0)*rhsStride];
+ for(int k=0; k<depth; k++)
+ {
+ ei_pstore(&blockB[count], ei_pset1(alpha*b0[k]));
+ count += PacketSize;
+ }
+ count += PacketSize * (stride-offset-depth);
+ }
+ }
+};
+
#endif // EIGEN_TRIANGULAR_SOLVER_MATRIX_H
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index d1c4d49e2..462032453 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -101,6 +101,7 @@ ei_add_test(product_extra ${EI_OFLAG})
ei_add_test(product_selfadjoint ${EI_OFLAG})
ei_add_test(product_symm ${EI_OFLAG})
ei_add_test(product_syrk ${EI_OFLAG})
+ei_add_test(product_trsm ${EI_OFLAG})
ei_add_test(diagonalmatrices)
ei_add_test(adjoint)
ei_add_test(submatrices)
diff --git a/test/product_trsm.cpp b/test/product_trsm.cpp
new file mode 100644
index 000000000..80df5861e
--- /dev/null
+++ b/test/product_trsm.cpp
@@ -0,0 +1,95 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2008-2009 Gael Guennebaud <gael.guennebaud@gmail.com>
+//
+// Eigen is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License as published by the Free Software Foundation; either
+// version 3 of the License, or (at your option) any later version.
+//
+// Alternatively, you can redistribute it and/or
+// modify it under the terms of the GNU General Public License as
+// published by the Free Software Foundation; either version 2 of
+// the License, or (at your option) any later version.
+//
+// Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
+// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License and a copy of the GNU General Public License along with
+// Eigen. If not, see <http://www.gnu.org/licenses/>.
+
+#include "main.h"
+
+template<typename Lhs, typename Rhs>
+void solve_ref(const Lhs& lhs, Rhs& rhs)
+{
+ for (int j=0; j<rhs.cols(); ++j)
+ lhs.solveInPlace(rhs.col(j));
+}
+
+template<typename Scalar> void trsm(int size,int cols)
+{
+ typedef typename NumTraits<Scalar>::Real RealScalar;
+
+ Matrix<Scalar,Dynamic,Dynamic,ColMajor> cmLhs(size,size);
+ Matrix<Scalar,Dynamic,Dynamic,RowMajor> rmLhs(size,size);
+
+ Matrix<Scalar,Dynamic,Dynamic,ColMajor> cmRef(size,cols), cmRhs(size,cols);
+ Matrix<Scalar,Dynamic,Dynamic,RowMajor> rmRef(size,cols), rmRhs(size,cols);
+
+ cmLhs.setRandom(); cmLhs.diagonal().cwise() += 10;
+ rmLhs.setRandom(); rmLhs.diagonal().cwise() += 10;
+
+ cmRhs.setRandom(); cmRef = cmRhs;
+ cmLhs.conjugate().template triangularView<LowerTriangular>().solveInPlace(cmRhs);
+ solve_ref(cmLhs.conjugate().template triangularView<LowerTriangular>(),cmRef);
+ VERIFY_IS_APPROX(cmRhs, cmRef);
+
+ cmRhs.setRandom(); cmRef = cmRhs;
+ cmLhs.conjugate().template triangularView<UpperTriangular>().solveInPlace(cmRhs);
+ solve_ref(cmLhs.conjugate().template triangularView<UpperTriangular>(),cmRef);
+ VERIFY_IS_APPROX(cmRhs, cmRef);
+
+ rmRhs.setRandom(); rmRef = rmRhs;
+ cmLhs.template triangularView<LowerTriangular>().solveInPlace(rmRhs);
+ solve_ref(cmLhs.template triangularView<LowerTriangular>(),rmRef);
+ VERIFY_IS_APPROX(rmRhs, rmRef);
+
+ rmRhs.setRandom(); rmRef = rmRhs;
+ cmLhs.template triangularView<UpperTriangular>().solveInPlace(rmRhs);
+ solve_ref(cmLhs.template triangularView<UpperTriangular>(),rmRef);
+ VERIFY_IS_APPROX(rmRhs, rmRef);
+
+
+ cmRhs.setRandom(); cmRef = cmRhs;
+ rmLhs.template triangularView<UnitLowerTriangular>().solveInPlace(cmRhs);
+ solve_ref(rmLhs.template triangularView<UnitLowerTriangular>(),cmRef);
+ VERIFY_IS_APPROX(cmRhs, cmRef);
+
+ cmRhs.setRandom(); cmRef = cmRhs;
+ rmLhs.template triangularView<UnitUpperTriangular>().solveInPlace(cmRhs);
+ solve_ref(rmLhs.template triangularView<UnitUpperTriangular>(),cmRef);
+ VERIFY_IS_APPROX(cmRhs, cmRef);
+
+ rmRhs.setRandom(); rmRef = rmRhs;
+ rmLhs.template triangularView<LowerTriangular>().solveInPlace(rmRhs);
+ solve_ref(rmLhs.template triangularView<LowerTriangular>(),rmRef);
+ VERIFY_IS_APPROX(rmRhs, rmRef);
+
+ rmRhs.setRandom(); rmRef = rmRhs;
+ rmLhs.template triangularView<UpperTriangular>().solveInPlace(rmRhs);
+ solve_ref(rmLhs.template triangularView<UpperTriangular>(),rmRef);
+ VERIFY_IS_APPROX(rmRhs, rmRef);
+}
+void test_product_trsm()
+{
+ for(int i = 0; i < g_repeat ; i++)
+ {
+ trsm<float>(ei_random<int>(1,320),ei_random<int>(1,320));
+ trsm<std::complex<double> >(ei_random<int>(1,320),ei_random<int>(1,320));
+ }
+}