diff options
-rw-r--r-- | Eigen/Core | 1 | ||||
-rw-r--r-- | Eigen/src/Core/SolveTriangular.h | 124 | ||||
-rw-r--r-- | Eigen/src/Core/products/TriangularSolverVector.h | 138 | ||||
-rw-r--r-- | test/product_trsolve.cpp | 5 |
4 files changed, 174 insertions, 94 deletions
diff --git a/Eigen/Core b/Eigen/Core index 4401a4e14..5e81d10d1 100644 --- a/Eigen/Core +++ b/Eigen/Core @@ -315,6 +315,7 @@ using std::size_t; #include "src/Core/products/TriangularMatrixVector.h" #include "src/Core/products/TriangularMatrixMatrix.h" #include "src/Core/products/TriangularSolverMatrix.h" +#include "src/Core/products/TriangularSolverVector.h" #include "src/Core/BandMatrix.h" #include "src/Core/BooleanRedux.h" diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h index abbf57553..b950d2c31 100644 --- a/Eigen/src/Core/SolveTriangular.h +++ b/Eigen/src/Core/SolveTriangular.h @@ -27,6 +27,15 @@ namespace internal { +// Forward declarations: +// The following two routines are implemented in the products/TriangularSolver*.h files +template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate, int StorageOrder> +struct triangular_solve_vector; + +template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder> +struct triangular_solve_matrix; + +// small helper struct extracting some traits on the underlying solver operation template<typename Lhs, typename Rhs, int Side> class trsolve_traits { @@ -51,111 +60,40 @@ template<typename Lhs, typename Rhs, > struct triangular_solver_selector; -// forward and backward substitution, row-major, rhs is a vector -template<typename Lhs, typename Rhs, int Mode> -struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,RowMajor,1> +template<typename Lhs, typename Rhs, int Mode, int StorageOrder> +struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,StorageOrder,1> { typedef typename Lhs::Scalar LhsScalar; typedef typename Rhs::Scalar RhsScalar; typedef blas_traits<Lhs> LhsProductTraits; typedef typename LhsProductTraits::ExtractType ActualLhsType; - typedef typename Lhs::Index Index; - enum { - IsLower = ((Mode&Lower)==Lower) - }; - static void run(const Lhs& lhs, Rhs& other) + typedef Map<Matrix<RhsScalar,Dynamic,1>, Aligned> MappedRhs; + static void run(const Lhs& lhs, Rhs& rhs) { - static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH; ActualLhsType actualLhs = LhsProductTraits::extract(lhs); - const Index size = lhs.cols(); - for(Index pi=IsLower ? 0 : size; - IsLower ? pi<size : pi>0; - IsLower ? pi+=PanelWidth : pi-=PanelWidth) - { - Index actualPanelWidth = std::min(IsLower ? size - pi : pi, PanelWidth); - - Index r = IsLower ? pi : size - pi; // remaining size - if (r > 0) - { - // let's directly call the low level product function because: - // 1 - it is faster to compile - // 2 - it is slighlty faster at runtime - Index startRow = IsLower ? pi : pi-actualPanelWidth; - Index startCol = IsLower ? 0 : pi; - - general_matrix_vector_product<Index,LhsScalar,RowMajor,LhsProductTraits::NeedToConjugate,RhsScalar,false>::run( - actualPanelWidth, r, - &(actualLhs.const_cast_derived().coeffRef(startRow,startCol)), actualLhs.outerStride(), - &(other.coeffRef(startCol)), other.innerStride(), - &other.coeffRef(startRow), other.innerStride(), - RhsScalar(-1)); - } - - for(Index k=0; k<actualPanelWidth; ++k) - { - Index i = IsLower ? pi+k : pi-k-1; - Index s = IsLower ? pi : i+1; - if (k>0) - other.coeffRef(i) -= (lhs.row(i).segment(s,k).transpose().cwiseProduct(other.segment(s,k))).sum(); + // FIXME find a way to allow an inner stride if packet_traits<Scalar>::size==1 - if(!(Mode & UnitDiag)) - other.coeffRef(i) /= lhs.coeff(i,i); - } + bool useRhsDirectly = Rhs::InnerStrideAtCompileTime==1 || rhs.innerStride()==1; + RhsScalar* actualRhs; + if(useRhsDirectly) + { + actualRhs = &rhs.coeffRef(0); } - } -}; - -// forward and backward substitution, column-major, rhs is a vector -template<typename Lhs, typename Rhs, int Mode> -struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,ColMajor,1> -{ - typedef typename Lhs::Scalar LhsScalar; - typedef typename Rhs::Scalar RhsScalar; - typedef blas_traits<Lhs> LhsProductTraits; - typedef typename LhsProductTraits::ExtractType ActualLhsType; - typedef typename Lhs::Index Index; - enum { - IsLower = ((Mode&Lower)==Lower) - }; - - static void run(const Lhs& lhs, Rhs& other) - { - static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH; - ActualLhsType actualLhs = LhsProductTraits::extract(lhs); - - const Index size = lhs.cols(); - for(Index pi=IsLower ? 0 : size; - IsLower ? pi<size : pi>0; - IsLower ? pi+=PanelWidth : pi-=PanelWidth) + else { - Index actualPanelWidth = std::min(IsLower ? size - pi : pi, PanelWidth); - Index startBlock = IsLower ? pi : pi-actualPanelWidth; - Index endBlock = IsLower ? pi + actualPanelWidth : 0; + actualRhs = ei_aligned_stack_new(RhsScalar,rhs.size()); + MappedRhs(actualRhs,rhs.size()) = rhs; + } - for(Index k=0; k<actualPanelWidth; ++k) - { - Index i = IsLower ? pi+k : pi-k-1; - if(!(Mode & UnitDiag)) - other.coeffRef(i) /= lhs.coeff(i,i); + + triangular_solve_vector<LhsScalar, RhsScalar, typename Lhs::Index, Mode, LhsProductTraits::NeedToConjugate, StorageOrder> + ::run(actualLhs.cols(), actualLhs.data(), actualLhs.outerStride(), actualRhs); - Index r = actualPanelWidth - k - 1; // remaining size - Index s = IsLower ? i+1 : i-r; - if (r>0) - other.segment(s,r) -= other.coeffRef(i) * Block<Lhs,Dynamic,1>(lhs, s, i, r, 1); - } - Index r = IsLower ? size - endBlock : startBlock; // remaining size - if (r > 0) - { - // let's directly call the low level product function because: - // 1 - it is faster to compile - // 2 - it is slighlty faster at runtime - general_matrix_vector_product<Index,LhsScalar,ColMajor,LhsProductTraits::NeedToConjugate,RhsScalar,false>::run( - r, actualPanelWidth, - &(actualLhs.const_cast_derived().coeffRef(endBlock,startBlock)), actualLhs.outerStride(), - &other.coeff(startBlock), other.innerStride(), - &(other.coeffRef(endBlock, 0)), other.innerStride(), RhsScalar(-1)); - } + if(!useRhsDirectly) + { + rhs = MappedRhs(actualRhs, rhs.size()); + ei_aligned_stack_delete(RhsScalar, actualRhs, rhs.size()); } } }; @@ -172,8 +110,6 @@ struct triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,Unrolling,StorageOrder } }; -template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder> -struct triangular_solve_matrix; // the rhs is a matrix template<typename Lhs, typename Rhs, int Side, int Mode, int StorageOrder> diff --git a/Eigen/src/Core/products/TriangularSolverVector.h b/Eigen/src/Core/products/TriangularSolverVector.h new file mode 100644 index 000000000..fcf8bcae0 --- /dev/null +++ b/Eigen/src/Core/products/TriangularSolverVector.h @@ -0,0 +1,138 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud@inria.fr> +// +// Eigen is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License as published by the Free Software Foundation; either +// version 3 of the License, or (at your option) any later version. +// +// Alternatively, you can redistribute it and/or +// modify it under the terms of the GNU General Public License as +// published by the Free Software Foundation; either version 2 of +// the License, or (at your option) any later version. +// +// Eigen is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public +// License and a copy of the GNU General Public License along with +// Eigen. If not, see <http://www.gnu.org/licenses/>. + +#ifndef EIGEN_TRIANGULAR_SOLVER_VECTOR_H +#define EIGEN_TRIANGULAR_SOLVER_VECTOR_H + +namespace internal { + +// forward and backward substitution, row-major, rhs is a vector +template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate> +struct triangular_solve_vector<LhsScalar, RhsScalar, Index, Mode, Conjugate, RowMajor> +{ + enum { + IsLower = ((Mode&Lower)==Lower) + }; + static void run(int size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs) + { + typedef Map<Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap; + const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride)); + typename internal::conditional< + Conjugate, + const CwiseUnaryOp<typename internal::scalar_conjugate_op<LhsScalar>,LhsMap>, + const LhsMap&> + ::type cjLhs(lhs); + static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH; + for(Index pi=IsLower ? 0 : size; + IsLower ? pi<size : pi>0; + IsLower ? pi+=PanelWidth : pi-=PanelWidth) + { + Index actualPanelWidth = std::min(IsLower ? size - pi : pi, PanelWidth); + + Index r = IsLower ? pi : size - pi; // remaining size + if (r > 0) + { + // let's directly call the low level product function because: + // 1 - it is faster to compile + // 2 - it is slighlty faster at runtime + Index startRow = IsLower ? pi : pi-actualPanelWidth; + Index startCol = IsLower ? 0 : pi; + + general_matrix_vector_product<Index,LhsScalar,RowMajor,Conjugate,RhsScalar,false>::run( + actualPanelWidth, r, + &(lhs.coeff(startRow,startCol)), lhsStride, + rhs + startCol, 1, + rhs + startRow, 1, + RhsScalar(-1)); + } + + for(Index k=0; k<actualPanelWidth; ++k) + { + Index i = IsLower ? pi+k : pi-k-1; + Index s = IsLower ? pi : i+1; + if (k>0) + rhs[i] -= (cjLhs.row(i).segment(s,k).transpose().cwiseProduct(Map<Matrix<RhsScalar,Dynamic,1> >(rhs+s,k))).sum(); + + if(!(Mode & UnitDiag)) + rhs[i] /= lhs(i,i); + } + } + } +}; + +// forward and backward substitution, column-major, rhs is a vector +template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate> +struct triangular_solve_vector<LhsScalar, RhsScalar, Index, Mode, Conjugate, ColMajor> +{ + enum { + IsLower = ((Mode&Lower)==Lower) + }; + static void run(int size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs) + { + typedef Map<Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap; + const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride)); + typename internal::conditional<Conjugate, + const CwiseUnaryOp<typename internal::scalar_conjugate_op<LhsScalar>,LhsMap>, + const LhsMap& + >::type cjLhs(lhs); + static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH; + + for(Index pi=IsLower ? 0 : size; + IsLower ? pi<size : pi>0; + IsLower ? pi+=PanelWidth : pi-=PanelWidth) + { + Index actualPanelWidth = std::min(IsLower ? size - pi : pi, PanelWidth); + Index startBlock = IsLower ? pi : pi-actualPanelWidth; + Index endBlock = IsLower ? pi + actualPanelWidth : 0; + + for(Index k=0; k<actualPanelWidth; ++k) + { + Index i = IsLower ? pi+k : pi-k-1; + if(!(Mode & UnitDiag)) + rhs[i] /= cjLhs.coeff(i,i); + + Index r = actualPanelWidth - k - 1; // remaining size + Index s = IsLower ? i+1 : i-r; + if (r>0) + Map<Matrix<RhsScalar,Dynamic,1> >(rhs+s,r) -= rhs[i] * cjLhs.col(i).segment(s,r); + } + Index r = IsLower ? size - endBlock : startBlock; // remaining size + if (r > 0) + { + // let's directly call the low level product function because: + // 1 - it is faster to compile + // 2 - it is slighlty faster at runtime + general_matrix_vector_product<Index,LhsScalar,ColMajor,Conjugate,RhsScalar,false>::run( + r, actualPanelWidth, + &(lhs.coeff(endBlock,startBlock)), lhsStride, + rhs+startBlock, 1, + rhs+endBlock, 1, RhsScalar(-1)); + } + } + } +}; + +} // end namespace internal + +#endif // EIGEN_TRIANGULAR_SOLVER_VECTOR_H diff --git a/test/product_trsolve.cpp b/test/product_trsolve.cpp index e7ada23a5..50aa37d45 100644 --- a/test/product_trsolve.cpp +++ b/test/product_trsolve.cpp @@ -73,6 +73,10 @@ template<typename Scalar,int Size, int Cols> void trsolve(int size=Size,int cols VERIFY_TRSM_ONTHERIGHT(rmLhs .template triangularView<Lower>(), cmRhs); VERIFY_TRSM_ONTHERIGHT(rmLhs.conjugate().template triangularView<UnitUpper>(), rmRhs); + + int c = internal::random<int>(0,cols-1); + VERIFY_TRSM(rmLhs.template triangularView<Lower>(), rmRhs.col(c)); + VERIFY_TRSM(cmLhs.template triangularView<Lower>(), rmRhs.col(c)); } void test_product_trsolve() @@ -86,6 +90,7 @@ void test_product_trsolve() CALL_SUBTEST_4((trsolve<std::complex<double>,Dynamic,Dynamic>(internal::random<int>(1,200),internal::random<int>(1,200)))); // vectors + CALL_SUBTEST_1((trsolve<float,Dynamic,1>(internal::random<int>(1,320)))); CALL_SUBTEST_5((trsolve<std::complex<double>,Dynamic,1>(internal::random<int>(1,320)))); CALL_SUBTEST_6((trsolve<float,1,1>())); CALL_SUBTEST_7((trsolve<float,1,2>())); |