diff options
-rw-r--r-- | Eigen/src/Core/SelfAdjointView.h | 99 | ||||
-rw-r--r-- | test/product_selfadjoint.cpp | 7 |
2 files changed, 69 insertions, 37 deletions
diff --git a/Eigen/src/Core/SelfAdjointView.h b/Eigen/src/Core/SelfAdjointView.h index 7f5fd7533..c64ebc174 100644 --- a/Eigen/src/Core/SelfAdjointView.h +++ b/Eigen/src/Core/SelfAdjointView.h @@ -53,8 +53,9 @@ struct ei_traits<SelfAdjointView<MatrixType, TriangularPart> > : ei_traits<Matri }; }; -template<typename Lhs,typename Rhs,bool RhsIsVector=Rhs::IsVectorAtCompileTime> -struct ei_selfadjoint_matrix_product_returntype; +template <typename Lhs, int LhsMode, bool LhsIsVector, + typename Rhs, int RhsMode, bool RhsIsVector> +struct ei_selfadjoint_product_returntype; // FIXME could also be called SelfAdjointWrapper to be consistent with DiagonalWrapper ?? template<typename MatrixType, unsigned int UpLo> class SelfAdjointView @@ -99,10 +100,22 @@ template<typename MatrixType, unsigned int UpLo> class SelfAdjointView /** Efficient self-adjoint matrix times vector/matrix product */ template<typename OtherDerived> - ei_selfadjoint_matrix_product_returntype<SelfAdjointView,OtherDerived> + ei_selfadjoint_product_returntype<MatrixType,Mode,false,OtherDerived,0,OtherDerived::IsVectorAtCompileTime> operator*(const MatrixBase<OtherDerived>& rhs) const { - return ei_selfadjoint_matrix_product_returntype<SelfAdjointView,OtherDerived>(*this, rhs.derived()); + return ei_selfadjoint_product_returntype + <MatrixType,Mode,false,OtherDerived,0,OtherDerived::IsVectorAtCompileTime> + (m_matrix, rhs.derived()); + } + + /** Efficient vector/matrix times self-adjoint matrix product */ + template<typename OtherDerived> friend + ei_selfadjoint_product_returntype<OtherDerived,0,OtherDerived::IsVectorAtCompileTime,MatrixType,Mode,false> + operator*(const MatrixBase<OtherDerived>& lhs, const SelfAdjointView& rhs) + { + return ei_selfadjoint_product_returntype + <OtherDerived,0,OtherDerived::IsVectorAtCompileTime,MatrixType,Mode,false> + (lhs.derived(),rhs.m_matrix); } /** Perform a symmetric rank 2 update of the selfadjoint matrix \c *this: @@ -125,6 +138,14 @@ template<typename MatrixType, unsigned int UpLo> class SelfAdjointView const typename MatrixType::Nested m_matrix; }; + +// template<typename OtherDerived, typename MatrixType, unsigned int UpLo> +// ei_selfadjoint_matrix_product_returntype<OtherDerived,SelfAdjointView<MatrixType,UpLo> > +// operator*(const MatrixBase<OtherDerived>& lhs, const SelfAdjointView<MatrixType,UpLo>& rhs) +// { +// return ei_matrix_selfadjoint_product_returntype<OtherDerived,SelfAdjointView<MatrixType,UpLo> >(lhs.derived(),rhs); +// } + template<typename Derived1, typename Derived2, int UnrollCount, bool ClearOpposite> struct ei_triangular_assignment_selector<Derived1, Derived2, SelfAdjoint, UnrollCount, ClearOpposite> { @@ -163,14 +184,14 @@ struct ei_triangular_assignment_selector<Derived1, Derived2, SelfAdjoint, Dynami * Wrapper to ei_product_selfadjoint_vector ***************************************************************************/ -template<typename Lhs,typename Rhs> -struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true> - : public ReturnByValue<ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true>, +template<typename Lhs, int LhsMode, typename Rhs, int RhsMode> +struct ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,RhsMode,true> + : public ReturnByValue<ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,RhsMode,true>, Matrix<typename ei_traits<Rhs>::Scalar, Rhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> > { typedef typename ei_cleantype<typename Rhs::Nested>::type RhsNested; - ei_selfadjoint_matrix_product_returntype(const Lhs& lhs, const Rhs& rhs) + ei_selfadjoint_product_returntype(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) {} @@ -178,10 +199,10 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true> { dst.resize(m_rhs.rows(), m_rhs.cols()); ei_product_selfadjoint_vector<typename Lhs::Scalar,ei_traits<Lhs>::Flags&RowMajorBit, - Lhs::Mode&(UpperTriangularBit|LowerTriangularBit)> + LhsMode&(UpperTriangularBit|LowerTriangularBit)> ( m_lhs.rows(), // size - m_lhs._expression().data(), // lhs + m_lhs.data(), // lhs m_lhs.stride(), // lhsStride, m_rhs.data(), // rhs // int rhsIncr, @@ -189,7 +210,7 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true> ); } - const Lhs m_lhs; + const typename Lhs::Nested m_lhs; const typename Rhs::Nested m_rhs; }; @@ -197,25 +218,36 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true> * Wrapper to ei_product_selfadjoint_matrix ***************************************************************************/ -template<typename Lhs,typename Rhs> -struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,false> - : public ReturnByValue<ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,false>, +template<typename Lhs, int LhsMode, typename Rhs, int RhsMode> +struct ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,RhsMode,false> + : public ReturnByValue<ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,RhsMode,false>, Matrix<typename ei_traits<Rhs>::Scalar, - Rhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> > + Lhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> > { - ei_selfadjoint_matrix_product_returntype(const Lhs& lhs, const Rhs& rhs) + ei_selfadjoint_product_returntype(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) {} typedef typename Lhs::Scalar Scalar; - typedef typename Rhs::Nested RhsNested; - typedef typename ei_cleantype<RhsNested>::type _RhsNested; - typedef typename ei_traits<Lhs>::ExpressionType LhsExpr; - typedef typename LhsExpr::Nested LhsNested; + typedef typename Lhs::Nested LhsNested; typedef typename ei_cleantype<LhsNested>::type _LhsNested; + typedef ei_blas_traits<_LhsNested> LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; + typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType; + + typedef typename Rhs::Nested RhsNested; + typedef typename ei_cleantype<RhsNested>::type _RhsNested; + typedef ei_blas_traits<_RhsNested> RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; + typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType; - enum { UpLo = ei_traits<Lhs>::Mode&(UpperTriangularBit|LowerTriangularBit) }; + enum { + LhsUpLo = LhsMode&(UpperTriangularBit|LowerTriangularBit), + LhsIsSelfAdjoint = (LhsMode&SelfAdjointBit)==SelfAdjointBit, + RhsUpLo = RhsMode&(UpperTriangularBit|LowerTriangularBit), + RhsIsSelfAdjoint = (RhsMode&SelfAdjointBit)==SelfAdjointBit + }; template<typename Dest> inline void _addTo(Dest& dst) const { evalTo(dst,1); } @@ -231,26 +263,19 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,false> template<typename Dest> void evalTo(Dest& dst, Scalar alpha) const { - typedef ei_blas_traits<_LhsNested> LhsBlasTraits; - typedef ei_blas_traits<_RhsNested> RhsBlasTraits; - - typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; - typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; - - typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType; - typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType; - - const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs._expression()); + const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs); const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs); - Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs._expression()) + Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs) * RhsBlasTraits::extractScalarFactor(m_rhs); ei_product_selfadjoint_matrix<Scalar, - EIGEN_LOGICAL_XOR(UpLo==UpperTriangular, - ei_traits<Lhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, true, - NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(UpLo==UpperTriangular,bool(LhsBlasTraits::NeedToConjugate)), - ei_traits<Rhs>::Flags &RowMajorBit ? RowMajor : ColMajor, false, bool(RhsBlasTraits::NeedToConjugate), + EIGEN_LOGICAL_XOR(LhsUpLo==UpperTriangular, + ei_traits<Lhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, LhsIsSelfAdjoint, + NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(LhsUpLo==UpperTriangular,bool(LhsBlasTraits::NeedToConjugate)), + EIGEN_LOGICAL_XOR(RhsUpLo==UpperTriangular, + ei_traits<Rhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, RhsIsSelfAdjoint, + NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(RhsUpLo==UpperTriangular,bool(RhsBlasTraits::NeedToConjugate)), ei_traits<Dest>::Flags&RowMajorBit ? RowMajor : ColMajor> ::run( lhs.rows(), rhs.cols(), // sizes @@ -261,7 +286,7 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,false> ); } - const Lhs m_lhs; + const LhsNested m_lhs; const RhsNested m_rhs; }; diff --git a/test/product_selfadjoint.cpp b/test/product_selfadjoint.cpp index 814d542e4..44bafad93 100644 --- a/test/product_selfadjoint.cpp +++ b/test/product_selfadjoint.cpp @@ -138,6 +138,13 @@ template<typename MatrixType> void symm(const MatrixType& m) m2 = m1.template triangularView<UpperTriangular>(); VERIFY_IS_APPROX(rhs32 = (s1*m2.adjoint()).template selfadjointView<LowerTriangular>() * (s2*rhs3).conjugate(), rhs33 = (s1*m1.adjoint()) * (s2*rhs3).conjugate()); + + // test matrix * selfadjoint + m2 = m1.template triangularView<LowerTriangular>(); + VERIFY_IS_APPROX(rhs22 = (rhs2) * (m2).template selfadjointView<LowerTriangular>(), + rhs23 = (rhs2) * (m1)); + VERIFY_IS_APPROX(rhs22 = (s2*rhs2) * (s1*m2).template selfadjointView<LowerTriangular>(), + rhs23 = (s2*rhs2) * (s1*m1)); } void test_product_selfadjoint() { |