aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2010-08-19 14:05:21 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2010-08-19 14:05:21 +0200
commit5354ffbb4f4d8db29097dba729da64be82b1f7b2 (patch)
tree0fcbc708814def3703b479a87fc7b612433eb397
parent6264755dd38a145d52fa280fe98d54136f2e72e9 (diff)
add missing specialization for vector * selfadjoint
-rw-r--r--Eigen/src/Core/products/SelfadjointMatrixVector.h42
-rw-r--r--test/product_symm.cpp22
2 files changed, 45 insertions, 19 deletions
diff --git a/Eigen/src/Core/products/SelfadjointMatrixVector.h b/Eigen/src/Core/products/SelfadjointMatrixVector.h
index 8a10075f0..df7509f9a 100644
--- a/Eigen/src/Core/products/SelfadjointMatrixVector.h
+++ b/Eigen/src/Core/products/SelfadjointMatrixVector.h
@@ -201,5 +201,47 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true>
}
};
+template<typename Lhs, typename Rhs, int RhsMode>
+struct ei_traits<SelfadjointProductMatrix<Lhs,0,true,Rhs,RhsMode,false> >
+ : ei_traits<ProductBase<SelfadjointProductMatrix<Lhs,0,true,Rhs,RhsMode,false>, Lhs, Rhs> >
+{};
+
+template<typename Lhs, typename Rhs, int RhsMode>
+struct SelfadjointProductMatrix<Lhs,0,true,Rhs,RhsMode,false>
+ : public ProductBase<SelfadjointProductMatrix<Lhs,0,true,Rhs,RhsMode,false>, Lhs, Rhs >
+{
+ EIGEN_PRODUCT_PUBLIC_INTERFACE(SelfadjointProductMatrix)
+
+ enum {
+ RhsUpLo = RhsMode&(Upper|Lower)
+ };
+
+ SelfadjointProductMatrix(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
+
+ template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
+ {
+ ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
+
+ const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
+ const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
+
+ Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
+ * RhsBlasTraits::extractScalarFactor(m_rhs);
+
+ ei_assert(dst.innerStride()==1 && "not implemented yet");
+
+ // transpose the product
+ ei_product_selfadjoint_vector<Scalar, Index, (ei_traits<_ActualRhsType>::Flags&RowMajorBit) ? ColMajor : RowMajor, int(RhsUpLo)==Upper ? Lower : Upper,
+ bool(RhsBlasTraits::NeedToConjugate), bool(LhsBlasTraits::NeedToConjugate)>
+ (
+ rhs.rows(), // size
+ &rhs.coeff(0,0), rhs.outerStride(), // lhs info
+ &lhs.coeff(0), lhs.innerStride(), // rhs info
+ &dst.coeffRef(0), // result info
+ actualAlpha // scale factor
+ );
+ }
+};
+
#endif // EIGEN_SELFADJOINT_MATRIX_VECTOR_H
diff --git a/test/product_symm.cpp b/test/product_symm.cpp
index 5064237ab..5ddae30c0 100644
--- a/test/product_symm.cpp
+++ b/test/product_symm.cpp
@@ -24,23 +24,6 @@
#include "main.h"
-template<int OtherSize> struct symm_extra {
- template<typename M1, typename M2, typename Scalar>
- static void run(M1& m1, M1& m2, M2& rhs2, M2& rhs22, M2& rhs23, Scalar s1, Scalar s2)
- {
- m2 = m1.template triangularView<Lower>();
- VERIFY_IS_APPROX(rhs22 = (rhs2) * (m2).template selfadjointView<Lower>(),
- rhs23 = (rhs2) * (m1));
- VERIFY_IS_APPROX(rhs22 = (s2*rhs2) * (s1*m2).template selfadjointView<Lower>(),
- rhs23 = (s2*rhs2) * (s1*m1));
- }
-};
-
-template<> struct symm_extra<1> {
- template<typename M1, typename M2, typename Scalar>
- static void run(M1&, M1&, M2&, M2&, M2&, Scalar, Scalar) {}
-};
-
template<typename Scalar, int Size, int OtherSize> void symm(int size = Size, int othersize = OtherSize)
{
typedef typename NumTraits<Scalar>::Real RealScalar;
@@ -105,8 +88,9 @@ template<typename Scalar, int Size, int OtherSize> void symm(int size = Size, in
VERIFY_IS_APPROX(rhs12.noalias() += s1 * ((m2.adjoint()).template selfadjointView<Lower>() * (s2*rhs3).conjugate()),
rhs13 += (s1*m1.adjoint()) * (s2*rhs3).conjugate());
- // test matrix * selfadjoint
- symm_extra<OtherSize>::run(m1,m2,rhs2,rhs22,rhs23,s1,s2);
+ m2 = m1.template triangularView<Lower>();
+ VERIFY_IS_APPROX(rhs22 = (rhs2) * (m2).template selfadjointView<Lower>(), rhs23 = (rhs2) * (m1));
+ VERIFY_IS_APPROX(rhs22 = (s2*rhs2) * (s1*m2).template selfadjointView<Lower>(), rhs23 = (s2*rhs2) * (s1*m1));
}