diff options
-rw-r--r-- | Eigen/src/Core/products/SelfadjointMatrixVector.h | 42 | ||||
-rw-r--r-- | test/product_symm.cpp | 22 |
2 files changed, 45 insertions, 19 deletions
diff --git a/Eigen/src/Core/products/SelfadjointMatrixVector.h b/Eigen/src/Core/products/SelfadjointMatrixVector.h index 8a10075f0..df7509f9a 100644 --- a/Eigen/src/Core/products/SelfadjointMatrixVector.h +++ b/Eigen/src/Core/products/SelfadjointMatrixVector.h @@ -201,5 +201,47 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true> } }; +template<typename Lhs, typename Rhs, int RhsMode> +struct ei_traits<SelfadjointProductMatrix<Lhs,0,true,Rhs,RhsMode,false> > + : ei_traits<ProductBase<SelfadjointProductMatrix<Lhs,0,true,Rhs,RhsMode,false>, Lhs, Rhs> > +{}; + +template<typename Lhs, typename Rhs, int RhsMode> +struct SelfadjointProductMatrix<Lhs,0,true,Rhs,RhsMode,false> + : public ProductBase<SelfadjointProductMatrix<Lhs,0,true,Rhs,RhsMode,false>, Lhs, Rhs > +{ + EIGEN_PRODUCT_PUBLIC_INTERFACE(SelfadjointProductMatrix) + + enum { + RhsUpLo = RhsMode&(Upper|Lower) + }; + + SelfadjointProductMatrix(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {} + + template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const + { + ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); + + const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs); + const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs); + + Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs) + * RhsBlasTraits::extractScalarFactor(m_rhs); + + ei_assert(dst.innerStride()==1 && "not implemented yet"); + + // transpose the product + ei_product_selfadjoint_vector<Scalar, Index, (ei_traits<_ActualRhsType>::Flags&RowMajorBit) ? ColMajor : RowMajor, int(RhsUpLo)==Upper ? Lower : Upper, + bool(RhsBlasTraits::NeedToConjugate), bool(LhsBlasTraits::NeedToConjugate)> + ( + rhs.rows(), // size + &rhs.coeff(0,0), rhs.outerStride(), // lhs info + &lhs.coeff(0), lhs.innerStride(), // rhs info + &dst.coeffRef(0), // result info + actualAlpha // scale factor + ); + } +}; + #endif // EIGEN_SELFADJOINT_MATRIX_VECTOR_H diff --git a/test/product_symm.cpp b/test/product_symm.cpp index 5064237ab..5ddae30c0 100644 --- a/test/product_symm.cpp +++ b/test/product_symm.cpp @@ -24,23 +24,6 @@ #include "main.h" -template<int OtherSize> struct symm_extra { - template<typename M1, typename M2, typename Scalar> - static void run(M1& m1, M1& m2, M2& rhs2, M2& rhs22, M2& rhs23, Scalar s1, Scalar s2) - { - m2 = m1.template triangularView<Lower>(); - VERIFY_IS_APPROX(rhs22 = (rhs2) * (m2).template selfadjointView<Lower>(), - rhs23 = (rhs2) * (m1)); - VERIFY_IS_APPROX(rhs22 = (s2*rhs2) * (s1*m2).template selfadjointView<Lower>(), - rhs23 = (s2*rhs2) * (s1*m1)); - } -}; - -template<> struct symm_extra<1> { - template<typename M1, typename M2, typename Scalar> - static void run(M1&, M1&, M2&, M2&, M2&, Scalar, Scalar) {} -}; - template<typename Scalar, int Size, int OtherSize> void symm(int size = Size, int othersize = OtherSize) { typedef typename NumTraits<Scalar>::Real RealScalar; @@ -105,8 +88,9 @@ template<typename Scalar, int Size, int OtherSize> void symm(int size = Size, in VERIFY_IS_APPROX(rhs12.noalias() += s1 * ((m2.adjoint()).template selfadjointView<Lower>() * (s2*rhs3).conjugate()), rhs13 += (s1*m1.adjoint()) * (s2*rhs3).conjugate()); - // test matrix * selfadjoint - symm_extra<OtherSize>::run(m1,m2,rhs2,rhs22,rhs23,s1,s2); + m2 = m1.template triangularView<Lower>(); + VERIFY_IS_APPROX(rhs22 = (rhs2) * (m2).template selfadjointView<Lower>(), rhs23 = (rhs2) * (m1)); + VERIFY_IS_APPROX(rhs22 = (s2*rhs2) * (s1*m2).template selfadjointView<Lower>(), rhs23 = (s2*rhs2) * (s1*m1)); } |