aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--Eigen/src/Core/SelfAdjointView.h99
-rw-r--r--test/product_selfadjoint.cpp7
2 files changed, 69 insertions, 37 deletions
diff --git a/Eigen/src/Core/SelfAdjointView.h b/Eigen/src/Core/SelfAdjointView.h
index 7f5fd7533..c64ebc174 100644
--- a/Eigen/src/Core/SelfAdjointView.h
+++ b/Eigen/src/Core/SelfAdjointView.h
@@ -53,8 +53,9 @@ struct ei_traits<SelfAdjointView<MatrixType, TriangularPart> > : ei_traits<Matri
};
};
-template<typename Lhs,typename Rhs,bool RhsIsVector=Rhs::IsVectorAtCompileTime>
-struct ei_selfadjoint_matrix_product_returntype;
+template <typename Lhs, int LhsMode, bool LhsIsVector,
+ typename Rhs, int RhsMode, bool RhsIsVector>
+struct ei_selfadjoint_product_returntype;
// FIXME could also be called SelfAdjointWrapper to be consistent with DiagonalWrapper ??
template<typename MatrixType, unsigned int UpLo> class SelfAdjointView
@@ -99,10 +100,22 @@ template<typename MatrixType, unsigned int UpLo> class SelfAdjointView
/** Efficient self-adjoint matrix times vector/matrix product */
template<typename OtherDerived>
- ei_selfadjoint_matrix_product_returntype<SelfAdjointView,OtherDerived>
+ ei_selfadjoint_product_returntype<MatrixType,Mode,false,OtherDerived,0,OtherDerived::IsVectorAtCompileTime>
operator*(const MatrixBase<OtherDerived>& rhs) const
{
- return ei_selfadjoint_matrix_product_returntype<SelfAdjointView,OtherDerived>(*this, rhs.derived());
+ return ei_selfadjoint_product_returntype
+ <MatrixType,Mode,false,OtherDerived,0,OtherDerived::IsVectorAtCompileTime>
+ (m_matrix, rhs.derived());
+ }
+
+ /** Efficient vector/matrix times self-adjoint matrix product */
+ template<typename OtherDerived> friend
+ ei_selfadjoint_product_returntype<OtherDerived,0,OtherDerived::IsVectorAtCompileTime,MatrixType,Mode,false>
+ operator*(const MatrixBase<OtherDerived>& lhs, const SelfAdjointView& rhs)
+ {
+ return ei_selfadjoint_product_returntype
+ <OtherDerived,0,OtherDerived::IsVectorAtCompileTime,MatrixType,Mode,false>
+ (lhs.derived(),rhs.m_matrix);
}
/** Perform a symmetric rank 2 update of the selfadjoint matrix \c *this:
@@ -125,6 +138,14 @@ template<typename MatrixType, unsigned int UpLo> class SelfAdjointView
const typename MatrixType::Nested m_matrix;
};
+
+// template<typename OtherDerived, typename MatrixType, unsigned int UpLo>
+// ei_selfadjoint_matrix_product_returntype<OtherDerived,SelfAdjointView<MatrixType,UpLo> >
+// operator*(const MatrixBase<OtherDerived>& lhs, const SelfAdjointView<MatrixType,UpLo>& rhs)
+// {
+// return ei_matrix_selfadjoint_product_returntype<OtherDerived,SelfAdjointView<MatrixType,UpLo> >(lhs.derived(),rhs);
+// }
+
template<typename Derived1, typename Derived2, int UnrollCount, bool ClearOpposite>
struct ei_triangular_assignment_selector<Derived1, Derived2, SelfAdjoint, UnrollCount, ClearOpposite>
{
@@ -163,14 +184,14 @@ struct ei_triangular_assignment_selector<Derived1, Derived2, SelfAdjoint, Dynami
* Wrapper to ei_product_selfadjoint_vector
***************************************************************************/
-template<typename Lhs,typename Rhs>
-struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true>
- : public ReturnByValue<ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true>,
+template<typename Lhs, int LhsMode, typename Rhs, int RhsMode>
+struct ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,RhsMode,true>
+ : public ReturnByValue<ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,RhsMode,true>,
Matrix<typename ei_traits<Rhs>::Scalar,
Rhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
{
typedef typename ei_cleantype<typename Rhs::Nested>::type RhsNested;
- ei_selfadjoint_matrix_product_returntype(const Lhs& lhs, const Rhs& rhs)
+ ei_selfadjoint_product_returntype(const Lhs& lhs, const Rhs& rhs)
: m_lhs(lhs), m_rhs(rhs)
{}
@@ -178,10 +199,10 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true>
{
dst.resize(m_rhs.rows(), m_rhs.cols());
ei_product_selfadjoint_vector<typename Lhs::Scalar,ei_traits<Lhs>::Flags&RowMajorBit,
- Lhs::Mode&(UpperTriangularBit|LowerTriangularBit)>
+ LhsMode&(UpperTriangularBit|LowerTriangularBit)>
(
m_lhs.rows(), // size
- m_lhs._expression().data(), // lhs
+ m_lhs.data(), // lhs
m_lhs.stride(), // lhsStride,
m_rhs.data(), // rhs
// int rhsIncr,
@@ -189,7 +210,7 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true>
);
}
- const Lhs m_lhs;
+ const typename Lhs::Nested m_lhs;
const typename Rhs::Nested m_rhs;
};
@@ -197,25 +218,36 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true>
* Wrapper to ei_product_selfadjoint_matrix
***************************************************************************/
-template<typename Lhs,typename Rhs>
-struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,false>
- : public ReturnByValue<ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,false>,
+template<typename Lhs, int LhsMode, typename Rhs, int RhsMode>
+struct ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,RhsMode,false>
+ : public ReturnByValue<ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,RhsMode,false>,
Matrix<typename ei_traits<Rhs>::Scalar,
- Rhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
+ Lhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
{
- ei_selfadjoint_matrix_product_returntype(const Lhs& lhs, const Rhs& rhs)
+ ei_selfadjoint_product_returntype(const Lhs& lhs, const Rhs& rhs)
: m_lhs(lhs), m_rhs(rhs)
{}
typedef typename Lhs::Scalar Scalar;
- typedef typename Rhs::Nested RhsNested;
- typedef typename ei_cleantype<RhsNested>::type _RhsNested;
- typedef typename ei_traits<Lhs>::ExpressionType LhsExpr;
- typedef typename LhsExpr::Nested LhsNested;
+ typedef typename Lhs::Nested LhsNested;
typedef typename ei_cleantype<LhsNested>::type _LhsNested;
+ typedef ei_blas_traits<_LhsNested> LhsBlasTraits;
+ typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
+ typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
+
+ typedef typename Rhs::Nested RhsNested;
+ typedef typename ei_cleantype<RhsNested>::type _RhsNested;
+ typedef ei_blas_traits<_RhsNested> RhsBlasTraits;
+ typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
+ typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
- enum { UpLo = ei_traits<Lhs>::Mode&(UpperTriangularBit|LowerTriangularBit) };
+ enum {
+ LhsUpLo = LhsMode&(UpperTriangularBit|LowerTriangularBit),
+ LhsIsSelfAdjoint = (LhsMode&SelfAdjointBit)==SelfAdjointBit,
+ RhsUpLo = RhsMode&(UpperTriangularBit|LowerTriangularBit),
+ RhsIsSelfAdjoint = (RhsMode&SelfAdjointBit)==SelfAdjointBit
+ };
template<typename Dest> inline void _addTo(Dest& dst) const
{ evalTo(dst,1); }
@@ -231,26 +263,19 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,false>
template<typename Dest> void evalTo(Dest& dst, Scalar alpha) const
{
- typedef ei_blas_traits<_LhsNested> LhsBlasTraits;
- typedef ei_blas_traits<_RhsNested> RhsBlasTraits;
-
- typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
- typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
-
- typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
- typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
-
- const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs._expression());
+ const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
- Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs._expression())
+ Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
* RhsBlasTraits::extractScalarFactor(m_rhs);
ei_product_selfadjoint_matrix<Scalar,
- EIGEN_LOGICAL_XOR(UpLo==UpperTriangular,
- ei_traits<Lhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, true,
- NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(UpLo==UpperTriangular,bool(LhsBlasTraits::NeedToConjugate)),
- ei_traits<Rhs>::Flags &RowMajorBit ? RowMajor : ColMajor, false, bool(RhsBlasTraits::NeedToConjugate),
+ EIGEN_LOGICAL_XOR(LhsUpLo==UpperTriangular,
+ ei_traits<Lhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, LhsIsSelfAdjoint,
+ NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(LhsUpLo==UpperTriangular,bool(LhsBlasTraits::NeedToConjugate)),
+ EIGEN_LOGICAL_XOR(RhsUpLo==UpperTriangular,
+ ei_traits<Rhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, RhsIsSelfAdjoint,
+ NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(RhsUpLo==UpperTriangular,bool(RhsBlasTraits::NeedToConjugate)),
ei_traits<Dest>::Flags&RowMajorBit ? RowMajor : ColMajor>
::run(
lhs.rows(), rhs.cols(), // sizes
@@ -261,7 +286,7 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,false>
);
}
- const Lhs m_lhs;
+ const LhsNested m_lhs;
const RhsNested m_rhs;
};
diff --git a/test/product_selfadjoint.cpp b/test/product_selfadjoint.cpp
index 814d542e4..44bafad93 100644
--- a/test/product_selfadjoint.cpp
+++ b/test/product_selfadjoint.cpp
@@ -138,6 +138,13 @@ template<typename MatrixType> void symm(const MatrixType& m)
m2 = m1.template triangularView<UpperTriangular>();
VERIFY_IS_APPROX(rhs32 = (s1*m2.adjoint()).template selfadjointView<LowerTriangular>() * (s2*rhs3).conjugate(),
rhs33 = (s1*m1.adjoint()) * (s2*rhs3).conjugate());
+
+ // test matrix * selfadjoint
+ m2 = m1.template triangularView<LowerTriangular>();
+ VERIFY_IS_APPROX(rhs22 = (rhs2) * (m2).template selfadjointView<LowerTriangular>(),
+ rhs23 = (rhs2) * (m1));
+ VERIFY_IS_APPROX(rhs22 = (s2*rhs2) * (s1*m2).template selfadjointView<LowerTriangular>(),
+ rhs23 = (s2*rhs2) * (s1*m1));
}
void test_product_selfadjoint()
{