diff options
-rw-r--r-- | Eigen/src/Core/products/TriangularMatrixVector.h | 64 | ||||
-rw-r--r-- | test/product_trmv.cpp | 6 |
2 files changed, 62 insertions, 8 deletions
diff --git a/Eigen/src/Core/products/TriangularMatrixVector.h b/Eigen/src/Core/products/TriangularMatrixVector.h index 1a2b183aa..935054a5a 100644 --- a/Eigen/src/Core/products/TriangularMatrixVector.h +++ b/Eigen/src/Core/products/TriangularMatrixVector.h @@ -25,12 +25,26 @@ #ifndef EIGEN_TRIANGULARMATRIXVECTOR_H #define EIGEN_TRIANGULARMATRIXVECTOR_H -template<typename MatrixType, typename Rhs, typename Result, +template<bool LhsIsTriangular, typename Lhs, typename Rhs, typename Result, int Mode, bool ConjLhs, bool ConjRhs, int StorageOrder> struct ei_product_triangular_vector_selector; +template<typename Lhs, typename Rhs, typename Result, int Mode, bool ConjLhs, bool ConjRhs, int StorageOrder> +struct ei_product_triangular_vector_selector<false,Lhs,Rhs,Result,Mode,ConjLhs,ConjRhs,StorageOrder> +{ + static EIGEN_DONT_INLINE void run(const Lhs& lhs, const Rhs& rhs, Result& res, typename ei_traits<Lhs>::Scalar alpha) + { + typedef Transpose<Rhs> TrRhs; TrRhs trRhs(rhs); + typedef Transpose<Lhs> TrLhs; TrLhs trLhs(lhs); + typedef Transpose<Result> TrRes; TrRes trRes(res); + ei_product_triangular_vector_selector<true,TrRhs,TrLhs,TrRes, + (Mode & UnitDiag) | (Mode & Lower) ? Upper : Lower, ConjRhs, ConjLhs, StorageOrder==RowMajor ? ColMajor : RowMajor> + ::run(trRhs,trLhs,trRes,alpha); + } +}; + template<typename Lhs, typename Rhs, typename Result, int Mode, bool ConjLhs, bool ConjRhs> -struct ei_product_triangular_vector_selector<Lhs,Rhs,Result,Mode,ConjLhs,ConjRhs,ColMajor> +struct ei_product_triangular_vector_selector<true,Lhs,Rhs,Result,Mode,ConjLhs,ConjRhs,ColMajor> { typedef typename Rhs::Scalar Scalar; typedef typename Rhs::Index Index; @@ -74,7 +88,7 @@ struct ei_product_triangular_vector_selector<Lhs,Rhs,Result,Mode,ConjLhs,ConjRhs }; template<typename Lhs, typename Rhs, typename Result, int Mode, bool ConjLhs, bool ConjRhs> -struct ei_product_triangular_vector_selector<Lhs,Rhs,Result,Mode,ConjLhs,ConjRhs,RowMajor> +struct ei_product_triangular_vector_selector<true,Lhs,Rhs,Result,Mode,ConjLhs,ConjRhs,RowMajor> { typedef typename Rhs::Scalar Scalar; typedef typename Rhs::Index Index; @@ -119,12 +133,17 @@ struct ei_product_triangular_vector_selector<Lhs,Rhs,Result,Mode,ConjLhs,ConjRhs * Wrapper to ei_product_triangular_vector ***************************************************************************/ -template<int Mode, /*bool LhsIsTriangular, */typename Lhs, typename Rhs> -struct ei_traits<TriangularProduct<Mode,true,Lhs,false,Rhs,true> > - : ei_traits<ProductBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true>, Lhs, Rhs> > +template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs> +struct ei_traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true> > + : ei_traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true>, Lhs, Rhs> > +{}; + +template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs> +struct ei_traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> > + : ei_traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> > {}; -template<int Mode, /*bool LhsIsTriangular, */typename Lhs, typename Rhs> +template<int Mode, typename Lhs, typename Rhs> struct TriangularProduct<Mode,true,Lhs,false,Rhs,true> : public ProductBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true>, Lhs, Rhs > { @@ -143,7 +162,7 @@ struct TriangularProduct<Mode,true,Lhs,false,Rhs,true> * RhsBlasTraits::extractScalarFactor(m_rhs); ei_product_triangular_vector_selector - <_ActualLhsType,_ActualRhsType,Dest, + <true,_ActualLhsType,_ActualRhsType,Dest, Mode, LhsBlasTraits::NeedToConjugate, RhsBlasTraits::NeedToConjugate, @@ -152,4 +171,33 @@ struct TriangularProduct<Mode,true,Lhs,false,Rhs,true> } }; +template<int Mode, typename Lhs, typename Rhs> +struct TriangularProduct<Mode,false,Lhs,true,Rhs,false> + : public ProductBase<TriangularProduct<Mode,false,Lhs,true,Rhs,false>, Lhs, Rhs > +{ + EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct) + + TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {} + + template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const + { + + ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); + + const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs); + const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs); + + Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs) + * RhsBlasTraits::extractScalarFactor(m_rhs); + + ei_product_triangular_vector_selector + <false,_ActualLhsType,_ActualRhsType,Dest, + Mode, + LhsBlasTraits::NeedToConjugate, + RhsBlasTraits::NeedToConjugate, + (int(ei_traits<Rhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor> + ::run(lhs,rhs,dst,actualAlpha); + } +}; + #endif // EIGEN_TRIANGULARMATRIXVECTOR_H diff --git a/test/product_trmv.cpp b/test/product_trmv.cpp index f0962557a..2f5743187 100644 --- a/test/product_trmv.cpp +++ b/test/product_trmv.cpp @@ -76,6 +76,12 @@ template<typename MatrixType> void trmv(const MatrixType& m) VERIFY((m3.adjoint() * (s1*v1.conjugate())).isApprox(m1.adjoint().template triangularView<Eigen::Upper>() * (s1*v1.conjugate()), largerEps)); m3 = m1.template triangularView<Eigen::UnitUpper>(); + // check transposed cases: + m3 = m1.template triangularView<Eigen::Lower>(); + VERIFY((v1.transpose() * m3).isApprox(v1.transpose() * m1.template triangularView<Eigen::Lower>(), largerEps)); + VERIFY((v1.adjoint() * m3).isApprox(v1.adjoint() * m1.template triangularView<Eigen::Lower>(), largerEps)); + VERIFY((v1.adjoint() * m3.adjoint()).isApprox(v1.adjoint() * m1.template triangularView<Eigen::Lower>().adjoint(), largerEps)); + // TODO check with sub-matrices } |