aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--Eigen/src/Core/products/TriangularMatrixVector.h64
-rw-r--r--test/product_trmv.cpp6
2 files changed, 62 insertions, 8 deletions
diff --git a/Eigen/src/Core/products/TriangularMatrixVector.h b/Eigen/src/Core/products/TriangularMatrixVector.h
index 1a2b183aa..935054a5a 100644
--- a/Eigen/src/Core/products/TriangularMatrixVector.h
+++ b/Eigen/src/Core/products/TriangularMatrixVector.h
@@ -25,12 +25,26 @@
#ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
#define EIGEN_TRIANGULARMATRIXVECTOR_H
-template<typename MatrixType, typename Rhs, typename Result,
+template<bool LhsIsTriangular, typename Lhs, typename Rhs, typename Result,
int Mode, bool ConjLhs, bool ConjRhs, int StorageOrder>
struct ei_product_triangular_vector_selector;
+template<typename Lhs, typename Rhs, typename Result, int Mode, bool ConjLhs, bool ConjRhs, int StorageOrder>
+struct ei_product_triangular_vector_selector<false,Lhs,Rhs,Result,Mode,ConjLhs,ConjRhs,StorageOrder>
+{
+ static EIGEN_DONT_INLINE void run(const Lhs& lhs, const Rhs& rhs, Result& res, typename ei_traits<Lhs>::Scalar alpha)
+ {
+ typedef Transpose<Rhs> TrRhs; TrRhs trRhs(rhs);
+ typedef Transpose<Lhs> TrLhs; TrLhs trLhs(lhs);
+ typedef Transpose<Result> TrRes; TrRes trRes(res);
+ ei_product_triangular_vector_selector<true,TrRhs,TrLhs,TrRes,
+ (Mode & UnitDiag) | (Mode & Lower) ? Upper : Lower, ConjRhs, ConjLhs, StorageOrder==RowMajor ? ColMajor : RowMajor>
+ ::run(trRhs,trLhs,trRes,alpha);
+ }
+};
+
template<typename Lhs, typename Rhs, typename Result, int Mode, bool ConjLhs, bool ConjRhs>
-struct ei_product_triangular_vector_selector<Lhs,Rhs,Result,Mode,ConjLhs,ConjRhs,ColMajor>
+struct ei_product_triangular_vector_selector<true,Lhs,Rhs,Result,Mode,ConjLhs,ConjRhs,ColMajor>
{
typedef typename Rhs::Scalar Scalar;
typedef typename Rhs::Index Index;
@@ -74,7 +88,7 @@ struct ei_product_triangular_vector_selector<Lhs,Rhs,Result,Mode,ConjLhs,ConjRhs
};
template<typename Lhs, typename Rhs, typename Result, int Mode, bool ConjLhs, bool ConjRhs>
-struct ei_product_triangular_vector_selector<Lhs,Rhs,Result,Mode,ConjLhs,ConjRhs,RowMajor>
+struct ei_product_triangular_vector_selector<true,Lhs,Rhs,Result,Mode,ConjLhs,ConjRhs,RowMajor>
{
typedef typename Rhs::Scalar Scalar;
typedef typename Rhs::Index Index;
@@ -119,12 +133,17 @@ struct ei_product_triangular_vector_selector<Lhs,Rhs,Result,Mode,ConjLhs,ConjRhs
* Wrapper to ei_product_triangular_vector
***************************************************************************/
-template<int Mode, /*bool LhsIsTriangular, */typename Lhs, typename Rhs>
-struct ei_traits<TriangularProduct<Mode,true,Lhs,false,Rhs,true> >
- : ei_traits<ProductBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true>, Lhs, Rhs> >
+template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
+struct ei_traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true> >
+ : ei_traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true>, Lhs, Rhs> >
+{};
+
+template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
+struct ei_traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> >
+ : ei_traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> >
{};
-template<int Mode, /*bool LhsIsTriangular, */typename Lhs, typename Rhs>
+template<int Mode, typename Lhs, typename Rhs>
struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
: public ProductBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true>, Lhs, Rhs >
{
@@ -143,7 +162,7 @@ struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
* RhsBlasTraits::extractScalarFactor(m_rhs);
ei_product_triangular_vector_selector
- <_ActualLhsType,_ActualRhsType,Dest,
+ <true,_ActualLhsType,_ActualRhsType,Dest,
Mode,
LhsBlasTraits::NeedToConjugate,
RhsBlasTraits::NeedToConjugate,
@@ -152,4 +171,33 @@ struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
}
};
+template<int Mode, typename Lhs, typename Rhs>
+struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
+ : public ProductBase<TriangularProduct<Mode,false,Lhs,true,Rhs,false>, Lhs, Rhs >
+{
+ EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
+
+ TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
+
+ template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
+ {
+
+ ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
+
+ const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
+ const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
+
+ Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
+ * RhsBlasTraits::extractScalarFactor(m_rhs);
+
+ ei_product_triangular_vector_selector
+ <false,_ActualLhsType,_ActualRhsType,Dest,
+ Mode,
+ LhsBlasTraits::NeedToConjugate,
+ RhsBlasTraits::NeedToConjugate,
+ (int(ei_traits<Rhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>
+ ::run(lhs,rhs,dst,actualAlpha);
+ }
+};
+
#endif // EIGEN_TRIANGULARMATRIXVECTOR_H
diff --git a/test/product_trmv.cpp b/test/product_trmv.cpp
index f0962557a..2f5743187 100644
--- a/test/product_trmv.cpp
+++ b/test/product_trmv.cpp
@@ -76,6 +76,12 @@ template<typename MatrixType> void trmv(const MatrixType& m)
VERIFY((m3.adjoint() * (s1*v1.conjugate())).isApprox(m1.adjoint().template triangularView<Eigen::Upper>() * (s1*v1.conjugate()), largerEps));
m3 = m1.template triangularView<Eigen::UnitUpper>();
+ // check transposed cases:
+ m3 = m1.template triangularView<Eigen::Lower>();
+ VERIFY((v1.transpose() * m3).isApprox(v1.transpose() * m1.template triangularView<Eigen::Lower>(), largerEps));
+ VERIFY((v1.adjoint() * m3).isApprox(v1.adjoint() * m1.template triangularView<Eigen::Lower>(), largerEps));
+ VERIFY((v1.adjoint() * m3.adjoint()).isApprox(v1.adjoint() * m1.template triangularView<Eigen::Lower>().adjoint(), largerEps));
+
// TODO check with sub-matrices
}