diff options
-rw-r--r-- | Eigen/src/Core/SolveTriangular.h | 20 | ||||
-rw-r--r-- | test/CMakeLists.txt | 2 | ||||
-rw-r--r-- | test/product_trsolve.cpp (renamed from test/product_trsm.cpp) | 22 |
3 files changed, 31 insertions, 13 deletions
diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h index c7f0cd227..e8230dd50 100644 --- a/Eigen/src/Core/SolveTriangular.h +++ b/Eigen/src/Core/SolveTriangular.h @@ -31,7 +31,7 @@ template<typename Lhs, typename Rhs, int Unrolling = Rhs::IsVectorAtCompileTime && Rhs::SizeAtCompileTime <= 8 // FIXME ? CompleteUnrolling : NoUnrolling, int StorageOrder = (int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor, - int RhsCols = Rhs::ColsAtCompileTime + int RhsVectors = Rhs::IsVectorAtCompileTime ? 1 : Dynamic > struct ei_triangular_solver_selector; @@ -143,18 +143,30 @@ struct ei_triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,ColMajor } }; +// transpose OnTheRight cases for vectors +template<typename Lhs, typename Rhs, int Mode, int Unrolling, int StorageOrder> +struct ei_triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,Unrolling,StorageOrder,1> +{ + static void run(const Lhs& lhs, Rhs& rhs) + { + Transpose<Rhs> rhsTr(rhs); + Transpose<Lhs> lhsTr(lhs); + ei_triangular_solver_selector<Transpose<Lhs>,Transpose<Rhs>,OnTheLeft,TriangularView<Lhs,Mode>::TransposeMode>::run(lhsTr,rhsTr); + } +}; + template <typename Scalar, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder> struct ei_triangular_solve_matrix; // the rhs is a matrix -template<typename Lhs, typename Rhs, int Side, int Mode, int StorageOrder, int RhsCols> -struct ei_triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,StorageOrder,RhsCols> +template<typename Lhs, typename Rhs, int Side, int Mode, int StorageOrder> +struct ei_triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,StorageOrder,Dynamic> { typedef typename Rhs::Scalar Scalar; typedef ei_blas_traits<Lhs> LhsProductTraits; typedef typename LhsProductTraits::DirectLinearAccessType ActualLhsType; static void run(const Lhs& lhs, Rhs& rhs) - { + {std::cerr << "mat\n"; const ActualLhsType actualLhs = LhsProductTraits::extract(lhs); ei_triangular_solve_matrix<Scalar,Side,Mode,LhsProductTraits::NeedToConjugate,StorageOrder, (Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor> diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index ffe89915a..b8efbcf51 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -116,7 +116,7 @@ ei_add_test(product_symm) ei_add_test(product_syrk) ei_add_test(product_trmv) ei_add_test(product_trmm) -ei_add_test(product_trsm) +ei_add_test(product_trsolve) ei_add_test(product_notemporary) ei_add_test(stable_norm) ei_add_test(bandmatrix) diff --git a/test/product_trsm.cpp b/test/product_trsolve.cpp index 1103e79a9..449240f7c 100644 --- a/test/product_trsm.cpp +++ b/test/product_trsolve.cpp @@ -36,15 +36,15 @@ VERIFY_IS_APPROX((XB).transpose() * (TRI).transpose().toDenseMatrix(), ref.transpose()); \ } -template<typename Scalar> void trsm(int size,int cols) +template<typename Scalar,int Size, int Cols> void trsolve(int size=Size,int cols=Cols) { typedef typename NumTraits<Scalar>::Real RealScalar; - Matrix<Scalar,Dynamic,Dynamic,ColMajor> cmLhs(size,size); - Matrix<Scalar,Dynamic,Dynamic,RowMajor> rmLhs(size,size); + Matrix<Scalar,Size,Size,ColMajor> cmLhs(size,size); + Matrix<Scalar,Size,Size,RowMajor> rmLhs(size,size); - Matrix<Scalar,Dynamic,Dynamic,ColMajor> cmRhs(size,cols), ref(size,cols); - Matrix<Scalar,Dynamic,Dynamic,RowMajor> rmRhs(size,cols); + Matrix<Scalar,Size,Cols,ColMajor> cmRhs(size,cols), ref(size,cols); + Matrix<Scalar,Size,Cols,RowMajor> rmRhs(size,cols); cmLhs.setRandom(); cmLhs *= static_cast<RealScalar>(0.1); cmLhs.diagonal().cwise() += static_cast<RealScalar>(1); rmLhs.setRandom(); rmLhs *= static_cast<RealScalar>(0.1); rmLhs.diagonal().cwise() += static_cast<RealScalar>(1); @@ -73,11 +73,17 @@ template<typename Scalar> void trsm(int size,int cols) VERIFY_TRSM_ONTHERIGHT(rmLhs.conjugate().template triangularView<UnitUpperTriangular>(), rmRhs); } -void test_product_trsm() +void test_product_trsolve() { for(int i = 0; i < g_repeat ; i++) { - CALL_SUBTEST_1((trsm<float>(ei_random<int>(1,320),ei_random<int>(1,320)))); - CALL_SUBTEST_2((trsm<std::complex<double> >(ei_random<int>(1,320),ei_random<int>(1,320)))); + // matrices + CALL_SUBTEST_1((trsolve<float,Dynamic,Dynamic>(ei_random<int>(1,320),ei_random<int>(1,320)))); + CALL_SUBTEST_2((trsolve<std::complex<double>,Dynamic,Dynamic>(ei_random<int>(1,320),ei_random<int>(1,320)))); + + // vectors + CALL_SUBTEST_3((trsolve<std::complex<double>,Dynamic,1>(ei_random<int>(1,320)))); + CALL_SUBTEST_4((trsolve<float,1,1>())); + CALL_SUBTEST_5((trsolve<std::complex<float>,4,1>())); } } |