aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2013-12-13 18:09:07 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2013-12-13 18:09:07 +0100
commit27c068e9d6230398b74a1c7b7146d7842c509de7 (patch)
tree63fb200a5fe280047d4b3c866c900f96b2eb52d9
parente94fe4cc3e371f37b39f7b5f824cd3acc74af823 (diff)
Make selfqdjoint products use evaluators
-rw-r--r--Eigen/src/Core/ProductEvaluators.h71
-rw-r--r--Eigen/src/Core/SelfAdjointView.h43
-rw-r--r--test/evaluators.cpp3
3 files changed, 113 insertions, 4 deletions
diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h
index c3a9f0db4..f0eb57d67 100644
--- a/Eigen/src/Core/ProductEvaluators.h
+++ b/Eigen/src/Core/ProductEvaluators.h
@@ -628,6 +628,77 @@ protected:
+/***************************************************************************
+* SelfAdjoint products
+***************************************************************************/
+
+template<typename Lhs, typename Rhs, int ProductTag>
+struct generic_product_impl<Lhs,Rhs,SelfAdjointShape,DenseShape,ProductTag>
+ : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,SelfAdjointShape,DenseShape,ProductTag> >
+{
+ typedef typename Product<Lhs,Rhs>::Scalar Scalar;
+
+ template<typename Dest>
+ static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
+ {// SelfadjointProductMatrix<MatrixType,Mode,false,OtherDerived,0,OtherDerived::IsVectorAtCompileTime>
+ // TODO bypass SelfadjointProductMatrix class
+ SelfadjointProductMatrix<typename Lhs::MatrixType,Lhs::Mode,false,Rhs,0,Rhs::IsVectorAtCompileTime>(lhs.nestedExpression(),rhs).scaleAndAddTo(dst, alpha);
+ }
+};
+
+template<typename Lhs, typename Rhs, int ProductTag>
+struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, SelfAdjointShape, DenseShape, typename Lhs::Scalar, typename Rhs::Scalar>
+ : public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>::type
+{
+ typedef Product<Lhs, Rhs, DefaultProduct> XprType;
+ typedef typename XprType::PlainObject PlainObject;
+ typedef typename evaluator<PlainObject>::type Base;
+
+ product_evaluator(const XprType& xpr)
+ : m_result(xpr.rows(), xpr.cols())
+ {
+ ::new (static_cast<Base*>(this)) Base(m_result);
+ generic_product_impl<Lhs, Rhs, SelfAdjointShape, DenseShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
+ }
+
+protected:
+ PlainObject m_result;
+};
+
+
+template<typename Lhs, typename Rhs, int ProductTag>
+struct generic_product_impl<Lhs,Rhs,DenseShape,SelfAdjointShape,ProductTag>
+: generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,SelfAdjointShape,ProductTag> >
+{
+ typedef typename Product<Lhs,Rhs>::Scalar Scalar;
+
+ template<typename Dest>
+ static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
+ {//SelfadjointProductMatrix<OtherDerived,0,OtherDerived::IsVectorAtCompileTime,MatrixType,Mode,false>
+ // TODO bypass SelfadjointProductMatrix class
+ SelfadjointProductMatrix<Lhs,0,Lhs::IsVectorAtCompileTime,typename Rhs::MatrixType,Rhs::Mode,false>(lhs,rhs.nestedExpression()).scaleAndAddTo(dst, alpha);
+ }
+};
+
+template<typename Lhs, typename Rhs, int ProductTag>
+struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, DenseShape, SelfAdjointShape, typename Lhs::Scalar, typename Rhs::Scalar>
+ : public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>::type
+{
+ typedef Product<Lhs, Rhs, DefaultProduct> XprType;
+ typedef typename XprType::PlainObject PlainObject;
+ typedef typename evaluator<PlainObject>::type Base;
+
+ product_evaluator(const XprType& xpr)
+ : m_result(xpr.rows(), xpr.cols())
+ {
+ ::new (static_cast<Base*>(this)) Base(m_result);
+ generic_product_impl<Lhs, Rhs, DenseShape, SelfAdjointShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
+ }
+
+protected:
+ PlainObject m_result;
+};
+
} // end namespace internal
diff --git a/Eigen/src/Core/SelfAdjointView.h b/Eigen/src/Core/SelfAdjointView.h
index 8231e3f5c..079b987f8 100644
--- a/Eigen/src/Core/SelfAdjointView.h
+++ b/Eigen/src/Core/SelfAdjointView.h
@@ -50,11 +50,12 @@ template <typename Lhs, int LhsMode, bool LhsIsVector,
struct SelfadjointProductMatrix;
// FIXME could also be called SelfAdjointWrapper to be consistent with DiagonalWrapper ??
-template<typename MatrixType, unsigned int UpLo> class SelfAdjointView
- : public TriangularBase<SelfAdjointView<MatrixType, UpLo> >
+template<typename _MatrixType, unsigned int UpLo> class SelfAdjointView
+ : public TriangularBase<SelfAdjointView<_MatrixType, UpLo> >
{
public:
+ typedef _MatrixType MatrixType;
typedef TriangularBase<SelfAdjointView> Base;
typedef typename internal::traits<SelfAdjointView>::MatrixTypeNested MatrixTypeNested;
typedef typename internal::traits<SelfAdjointView>::MatrixTypeNestedCleaned MatrixTypeNestedCleaned;
@@ -65,7 +66,8 @@ template<typename MatrixType, unsigned int UpLo> class SelfAdjointView
typedef typename MatrixType::Index Index;
enum {
- Mode = internal::traits<SelfAdjointView>::Mode
+ Mode = internal::traits<SelfAdjointView>::Mode,
+ Flags = internal::traits<SelfAdjointView>::Flags
};
typedef typename MatrixType::PlainObject PlainObject;
@@ -111,6 +113,28 @@ template<typename MatrixType, unsigned int UpLo> class SelfAdjointView
EIGEN_DEVICE_FUNC
MatrixTypeNestedCleaned& nestedExpression() { return *const_cast<MatrixTypeNestedCleaned*>(&m_matrix); }
+#ifdef EIGEN_TEST_EVALUATORS
+
+ /** Efficient triangular matrix times vector/matrix product */
+ template<typename OtherDerived>
+ EIGEN_DEVICE_FUNC
+ const Product<SelfAdjointView,OtherDerived>
+ operator*(const MatrixBase<OtherDerived>& rhs) const
+ {
+ return Product<SelfAdjointView,OtherDerived>(*this, rhs.derived());
+ }
+
+ /** Efficient vector/matrix times triangular matrix product */
+ template<typename OtherDerived> friend
+ EIGEN_DEVICE_FUNC
+ const Product<OtherDerived,SelfAdjointView>
+ operator*(const MatrixBase<OtherDerived>& lhs, const SelfAdjointView& rhs)
+ {
+ return Product<OtherDerived,SelfAdjointView>(lhs.derived(),rhs);
+ }
+
+#else // EIGEN_TEST_EVALUATORS
+
/** Efficient self-adjoint matrix times vector/matrix product */
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
@@ -132,6 +156,7 @@ template<typename MatrixType, unsigned int UpLo> class SelfAdjointView
<OtherDerived,0,OtherDerived::IsVectorAtCompileTime,MatrixType,Mode,false>
(lhs.derived(),rhs.m_matrix);
}
+#endif
/** Perform a symmetric rank 2 update of the selfadjoint matrix \c *this:
* \f$ this = this + \alpha u v^* + conj(\alpha) v u^* \f$
@@ -311,6 +336,18 @@ struct triangular_assignment_selector<Derived1, Derived2, SelfAdjoint|Lower, Dyn
}
};
+// TODO currently a selfadjoint expression has the form SelfAdjointView<.,.>
+// in the future selfadjoint-ness should be defined by the expression traits
+// such that Transpose<SelfAdjointView<.,.> > is valid. (currently TriangularBase::transpose() is overloaded to make it work)
+template<typename MatrixType, unsigned int Mode>
+struct evaluator_traits<SelfAdjointView<MatrixType,Mode> >
+{
+ typedef typename storage_kind_to_evaluator_kind<typename MatrixType::StorageKind>::Kind Kind;
+ typedef SelfAdjointShape Shape;
+
+ static const int AssumeAliasing = 0;
+};
+
} // end namespace internal
/***************************************************************************
diff --git a/test/evaluators.cpp b/test/evaluators.cpp
index d4b737348..69a45661f 100644
--- a/test/evaluators.cpp
+++ b/test/evaluators.cpp
@@ -455,6 +455,7 @@ void test_evaluators()
VERIFY_IS_APPROX_EVALUATOR2(B, prod(A.triangularView<Upper>(),A), MatrixXd(A.triangularView<Upper>()*A));
- B.col(0).noalias() = prod( (2.1 * A.adjoint()).triangularView<UnitUpper>() , (A.row(0)).adjoint() );
+ VERIFY_IS_APPROX_EVALUATOR2(B, prod(A.selfadjointView<Upper>(),A), MatrixXd(A.selfadjointView<Upper>()*A));
+
}
}