diff options
author | Gael Guennebaud <g.gael@free.fr> | 2013-12-13 18:09:07 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2013-12-13 18:09:07 +0100 |
commit | 27c068e9d6230398b74a1c7b7146d7842c509de7 (patch) | |
tree | 63fb200a5fe280047d4b3c866c900f96b2eb52d9 | |
parent | e94fe4cc3e371f37b39f7b5f824cd3acc74af823 (diff) |
Make selfqdjoint products use evaluators
-rw-r--r-- | Eigen/src/Core/ProductEvaluators.h | 71 | ||||
-rw-r--r-- | Eigen/src/Core/SelfAdjointView.h | 43 | ||||
-rw-r--r-- | test/evaluators.cpp | 3 |
3 files changed, 113 insertions, 4 deletions
diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h index c3a9f0db4..f0eb57d67 100644 --- a/Eigen/src/Core/ProductEvaluators.h +++ b/Eigen/src/Core/ProductEvaluators.h @@ -628,6 +628,77 @@ protected: +/*************************************************************************** +* SelfAdjoint products +***************************************************************************/ + +template<typename Lhs, typename Rhs, int ProductTag> +struct generic_product_impl<Lhs,Rhs,SelfAdjointShape,DenseShape,ProductTag> + : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,SelfAdjointShape,DenseShape,ProductTag> > +{ + typedef typename Product<Lhs,Rhs>::Scalar Scalar; + + template<typename Dest> + static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) + {// SelfadjointProductMatrix<MatrixType,Mode,false,OtherDerived,0,OtherDerived::IsVectorAtCompileTime> + // TODO bypass SelfadjointProductMatrix class + SelfadjointProductMatrix<typename Lhs::MatrixType,Lhs::Mode,false,Rhs,0,Rhs::IsVectorAtCompileTime>(lhs.nestedExpression(),rhs).scaleAndAddTo(dst, alpha); + } +}; + +template<typename Lhs, typename Rhs, int ProductTag> +struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, SelfAdjointShape, DenseShape, typename Lhs::Scalar, typename Rhs::Scalar> + : public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>::type +{ + typedef Product<Lhs, Rhs, DefaultProduct> XprType; + typedef typename XprType::PlainObject PlainObject; + typedef typename evaluator<PlainObject>::type Base; + + product_evaluator(const XprType& xpr) + : m_result(xpr.rows(), xpr.cols()) + { + ::new (static_cast<Base*>(this)) Base(m_result); + generic_product_impl<Lhs, Rhs, SelfAdjointShape, DenseShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs()); + } + +protected: + PlainObject m_result; +}; + + +template<typename Lhs, typename Rhs, int ProductTag> +struct generic_product_impl<Lhs,Rhs,DenseShape,SelfAdjointShape,ProductTag> +: generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,SelfAdjointShape,ProductTag> > +{ + typedef typename Product<Lhs,Rhs>::Scalar Scalar; + + template<typename Dest> + static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) + {//SelfadjointProductMatrix<OtherDerived,0,OtherDerived::IsVectorAtCompileTime,MatrixType,Mode,false> + // TODO bypass SelfadjointProductMatrix class + SelfadjointProductMatrix<Lhs,0,Lhs::IsVectorAtCompileTime,typename Rhs::MatrixType,Rhs::Mode,false>(lhs,rhs.nestedExpression()).scaleAndAddTo(dst, alpha); + } +}; + +template<typename Lhs, typename Rhs, int ProductTag> +struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, DenseShape, SelfAdjointShape, typename Lhs::Scalar, typename Rhs::Scalar> + : public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>::type +{ + typedef Product<Lhs, Rhs, DefaultProduct> XprType; + typedef typename XprType::PlainObject PlainObject; + typedef typename evaluator<PlainObject>::type Base; + + product_evaluator(const XprType& xpr) + : m_result(xpr.rows(), xpr.cols()) + { + ::new (static_cast<Base*>(this)) Base(m_result); + generic_product_impl<Lhs, Rhs, DenseShape, SelfAdjointShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs()); + } + +protected: + PlainObject m_result; +}; + } // end namespace internal diff --git a/Eigen/src/Core/SelfAdjointView.h b/Eigen/src/Core/SelfAdjointView.h index 8231e3f5c..079b987f8 100644 --- a/Eigen/src/Core/SelfAdjointView.h +++ b/Eigen/src/Core/SelfAdjointView.h @@ -50,11 +50,12 @@ template <typename Lhs, int LhsMode, bool LhsIsVector, struct SelfadjointProductMatrix; // FIXME could also be called SelfAdjointWrapper to be consistent with DiagonalWrapper ?? -template<typename MatrixType, unsigned int UpLo> class SelfAdjointView - : public TriangularBase<SelfAdjointView<MatrixType, UpLo> > +template<typename _MatrixType, unsigned int UpLo> class SelfAdjointView + : public TriangularBase<SelfAdjointView<_MatrixType, UpLo> > { public: + typedef _MatrixType MatrixType; typedef TriangularBase<SelfAdjointView> Base; typedef typename internal::traits<SelfAdjointView>::MatrixTypeNested MatrixTypeNested; typedef typename internal::traits<SelfAdjointView>::MatrixTypeNestedCleaned MatrixTypeNestedCleaned; @@ -65,7 +66,8 @@ template<typename MatrixType, unsigned int UpLo> class SelfAdjointView typedef typename MatrixType::Index Index; enum { - Mode = internal::traits<SelfAdjointView>::Mode + Mode = internal::traits<SelfAdjointView>::Mode, + Flags = internal::traits<SelfAdjointView>::Flags }; typedef typename MatrixType::PlainObject PlainObject; @@ -111,6 +113,28 @@ template<typename MatrixType, unsigned int UpLo> class SelfAdjointView EIGEN_DEVICE_FUNC MatrixTypeNestedCleaned& nestedExpression() { return *const_cast<MatrixTypeNestedCleaned*>(&m_matrix); } +#ifdef EIGEN_TEST_EVALUATORS + + /** Efficient triangular matrix times vector/matrix product */ + template<typename OtherDerived> + EIGEN_DEVICE_FUNC + const Product<SelfAdjointView,OtherDerived> + operator*(const MatrixBase<OtherDerived>& rhs) const + { + return Product<SelfAdjointView,OtherDerived>(*this, rhs.derived()); + } + + /** Efficient vector/matrix times triangular matrix product */ + template<typename OtherDerived> friend + EIGEN_DEVICE_FUNC + const Product<OtherDerived,SelfAdjointView> + operator*(const MatrixBase<OtherDerived>& lhs, const SelfAdjointView& rhs) + { + return Product<OtherDerived,SelfAdjointView>(lhs.derived(),rhs); + } + +#else // EIGEN_TEST_EVALUATORS + /** Efficient self-adjoint matrix times vector/matrix product */ template<typename OtherDerived> EIGEN_DEVICE_FUNC @@ -132,6 +156,7 @@ template<typename MatrixType, unsigned int UpLo> class SelfAdjointView <OtherDerived,0,OtherDerived::IsVectorAtCompileTime,MatrixType,Mode,false> (lhs.derived(),rhs.m_matrix); } +#endif /** Perform a symmetric rank 2 update of the selfadjoint matrix \c *this: * \f$ this = this + \alpha u v^* + conj(\alpha) v u^* \f$ @@ -311,6 +336,18 @@ struct triangular_assignment_selector<Derived1, Derived2, SelfAdjoint|Lower, Dyn } }; +// TODO currently a selfadjoint expression has the form SelfAdjointView<.,.> +// in the future selfadjoint-ness should be defined by the expression traits +// such that Transpose<SelfAdjointView<.,.> > is valid. (currently TriangularBase::transpose() is overloaded to make it work) +template<typename MatrixType, unsigned int Mode> +struct evaluator_traits<SelfAdjointView<MatrixType,Mode> > +{ + typedef typename storage_kind_to_evaluator_kind<typename MatrixType::StorageKind>::Kind Kind; + typedef SelfAdjointShape Shape; + + static const int AssumeAliasing = 0; +}; + } // end namespace internal /*************************************************************************** diff --git a/test/evaluators.cpp b/test/evaluators.cpp index d4b737348..69a45661f 100644 --- a/test/evaluators.cpp +++ b/test/evaluators.cpp @@ -455,6 +455,7 @@ void test_evaluators() VERIFY_IS_APPROX_EVALUATOR2(B, prod(A.triangularView<Upper>(),A), MatrixXd(A.triangularView<Upper>()*A)); - B.col(0).noalias() = prod( (2.1 * A.adjoint()).triangularView<UnitUpper>() , (A.row(0)).adjoint() ); + VERIFY_IS_APPROX_EVALUATOR2(B, prod(A.selfadjointView<Upper>(),A), MatrixXd(A.selfadjointView<Upper>()*A)); + } } |