diff options
author | Jitse Niesen <jitse@maths.leeds.ac.uk> | 2012-06-29 13:49:25 +0100 |
---|---|---|
committer | Jitse Niesen <jitse@maths.leeds.ac.uk> | 2012-06-29 13:49:25 +0100 |
commit | d0b873822f52f5739868ba322dae5b3d0c399a4d (patch) | |
tree | 8c6af7d69559a24555fd9531a4d3dda2ae9ccdd1 | |
parent | 2393ceb38048506b799689e7bc109a4db5d09e99 (diff) |
Make product eval-at-once.
* Make product EvalAtOnce in cases OuterProduct, GemmProduct and
GemvProduct
* Ensure that product evaluators are nested inside EvalToTemp
evaluator
* As a temporary kludge, evaluate the expression to a temporary in AllAtOnce
traversal and pass the expression object to evalTo()
-rw-r--r-- | Eigen/src/Core/AssignEvaluator.h | 8 | ||||
-rw-r--r-- | Eigen/src/Core/CoreEvaluators.h | 207 | ||||
-rw-r--r-- | Eigen/src/Core/ProductEvaluators.h | 44 | ||||
-rw-r--r-- | test/evaluators.cpp | 5 |
4 files changed, 173 insertions, 91 deletions
diff --git a/Eigen/src/Core/AssignEvaluator.h b/Eigen/src/Core/AssignEvaluator.h index 08a2c696a..9be00067d 100644 --- a/Eigen/src/Core/AssignEvaluator.h +++ b/Eigen/src/Core/AssignEvaluator.h @@ -616,7 +616,13 @@ struct copy_using_evaluator_impl<DstXprType, SrcXprType, AllAtOnceTraversal, NoU DstEvaluatorType dstEvaluator(dst); SrcEvaluatorType srcEvaluator(src); - srcEvaluator.evalTo(dstEvaluator); + // Evaluate rhs in temporary to prevent aliasing problems in a = a * a; + // TODO: Be smarter about this + // TODO: Do not pass the xpr object to evalTo() + typename DstXprType::PlainObject tmp; + typename evaluator<typename DstXprType::PlainObject>::type tmpEvaluator(tmp); + srcEvaluator.evalTo(tmpEvaluator, tmp); + copy_using_evaluator(dst, tmp); } }; diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h index 768fa8950..808546ec1 100644 --- a/Eigen/src/Core/CoreEvaluators.h +++ b/Eigen/src/Core/CoreEvaluators.h @@ -3,7 +3,7 @@ // // Copyright (C) 2011 Benoit Jacob <jacob.benoit.1@gmail.com> // Copyright (C) 2011 Gael Guennebaud <gael.guennebaud@inria.fr> -// Copyright (C) 2011 Jitse Niesen <jitse@maths.leeds.ac.uk> +// Copyright (C) 2011-2012 Jitse Niesen <jitse@maths.leeds.ac.uk> // // Eigen is free software; you can redistribute it and/or // modify it under the terms of the GNU Lesser General Public @@ -42,24 +42,46 @@ struct evaluator_traits static const int HasEvalTo = 0; }; +// expression class for evaluating nested expression to a temporary + +template<typename ArgType> +class EvalToTemp; + // evaluator<T>::type is type of evaluator for T +// evaluator<T>::nestedType is type of evaluator if T is nested inside another evaluator + +template<typename T> +struct evaluator_impl +{ }; + +template<typename T, int Nested = evaluator_traits<T>::HasEvalTo> +struct evaluator_nested_type; template<typename T> -struct evaluator_impl {}; +struct evaluator_nested_type<T, 0> +{ + typedef evaluator_impl<T> type; +}; + +template<typename T> +struct 
evaluator_nested_type<T, 1> +{ + typedef evaluator_impl<EvalToTemp<T> > type; +}; template<typename T> struct evaluator { typedef evaluator_impl<T> type; + typedef typename evaluator_nested_type<T>::type nestedType; }; // TODO: Think about const-correctness template<typename T> struct evaluator<const T> -{ - typedef evaluator_impl<T> type; -}; + : evaluator<T> +{ }; // ---------- base class for all writable evaluators ---------- @@ -132,70 +154,6 @@ struct evaluator_impl_base } }; -// -------------------- Transpose -------------------- - -template<typename ArgType> -struct evaluator_impl<Transpose<ArgType> > - : evaluator_impl_base<Transpose<ArgType> > -{ - typedef Transpose<ArgType> XprType; - - evaluator_impl(const XprType& t) : m_argImpl(t.nestedExpression()) {} - - typedef typename XprType::Index Index; - typedef typename XprType::Scalar Scalar; - typedef typename XprType::CoeffReturnType CoeffReturnType; - typedef typename XprType::PacketScalar PacketScalar; - typedef typename XprType::PacketReturnType PacketReturnType; - - CoeffReturnType coeff(Index row, Index col) const - { - return m_argImpl.coeff(col, row); - } - - CoeffReturnType coeff(Index index) const - { - return m_argImpl.coeff(index); - } - - Scalar& coeffRef(Index row, Index col) - { - return m_argImpl.coeffRef(col, row); - } - - typename XprType::Scalar& coeffRef(Index index) - { - return m_argImpl.coeffRef(index); - } - - template<int LoadMode> - PacketReturnType packet(Index row, Index col) const - { - return m_argImpl.template packet<LoadMode>(col, row); - } - - template<int LoadMode> - PacketReturnType packet(Index index) const - { - return m_argImpl.template packet<LoadMode>(index); - } - - template<int StoreMode> - void writePacket(Index row, Index col, const PacketScalar& x) - { - m_argImpl.template writePacket<StoreMode>(col, row, x); - } - - template<int StoreMode> - void writePacket(Index index, const PacketScalar& x) - { - m_argImpl.template writePacket<StoreMode>(index, x); - } - 
-protected: - typename evaluator<ArgType>::type m_argImpl; -}; - // -------------------- Matrix and Array -------------------- // // evaluator_impl<PlainObjectBase> is a common base class for the @@ -285,6 +243,89 @@ struct evaluator_impl<Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> > { } }; +// -------------------- EvalToTemp -------------------- + +template<typename ArgType> +struct evaluator_impl<EvalToTemp<ArgType> > + : evaluator_impl<typename ArgType::PlainObject> +{ + typedef typename ArgType::PlainObject PlainObject; + typedef evaluator_impl<PlainObject> BaseType; + + evaluator_impl(const ArgType& arg) + : BaseType(m_result) + { + copy_using_evaluator(m_result, arg); + }; + +protected: + PlainObject m_result; +}; + +// -------------------- Transpose -------------------- + +template<typename ArgType> +struct evaluator_impl<Transpose<ArgType> > + : evaluator_impl_base<Transpose<ArgType> > +{ + typedef Transpose<ArgType> XprType; + + evaluator_impl(const XprType& t) : m_argImpl(t.nestedExpression()) {} + + typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketScalar PacketScalar; + typedef typename XprType::PacketReturnType PacketReturnType; + + CoeffReturnType coeff(Index row, Index col) const + { + return m_argImpl.coeff(col, row); + } + + CoeffReturnType coeff(Index index) const + { + return m_argImpl.coeff(index); + } + + Scalar& coeffRef(Index row, Index col) + { + return m_argImpl.coeffRef(col, row); + } + + typename XprType::Scalar& coeffRef(Index index) + { + return m_argImpl.coeffRef(index); + } + + template<int LoadMode> + PacketReturnType packet(Index row, Index col) const + { + return m_argImpl.template packet<LoadMode>(col, row); + } + + template<int LoadMode> + PacketReturnType packet(Index index) const + { + return m_argImpl.template packet<LoadMode>(index); + } + + template<int StoreMode> + void writePacket(Index 
row, Index col, const PacketScalar& x) + { + m_argImpl.template writePacket<StoreMode>(col, row, x); + } + + template<int StoreMode> + void writePacket(Index index, const PacketScalar& x) + { + m_argImpl.template writePacket<StoreMode>(index, x); + } + +protected: + typename evaluator<ArgType>::nestedType m_argImpl; +}; + // -------------------- CwiseNullaryOp -------------------- template<typename NullaryOp, typename PlainObjectType> @@ -366,7 +407,7 @@ struct evaluator_impl<CwiseUnaryOp<UnaryOp, ArgType> > protected: const UnaryOp m_functor; - typename evaluator<ArgType>::type m_argImpl; + typename evaluator<ArgType>::nestedType m_argImpl; }; // -------------------- CwiseBinaryOp -------------------- @@ -412,8 +453,8 @@ struct evaluator_impl<CwiseBinaryOp<BinaryOp, Lhs, Rhs> > protected: const BinaryOp m_functor; - typename evaluator<Lhs>::type m_lhsImpl; - typename evaluator<Rhs>::type m_rhsImpl; + typename evaluator<Lhs>::nestedType m_lhsImpl; + typename evaluator<Rhs>::nestedType m_rhsImpl; }; // -------------------- CwiseUnaryView -------------------- @@ -455,7 +496,7 @@ struct evaluator_impl<CwiseUnaryView<UnaryOp, ArgType> > protected: const UnaryOp m_unaryOp; - typename evaluator<ArgType>::type m_argImpl; + typename evaluator<ArgType>::nestedType m_argImpl; }; // -------------------- Map -------------------- @@ -626,7 +667,7 @@ struct evaluator_impl<Block<ArgType, BlockRows, BlockCols, InnerPanel, /* HasDir } protected: - typename evaluator<ArgType>::type m_argImpl; + typename evaluator<ArgType>::nestedType m_argImpl; // TODO: Get rid of m_startRow, m_startCol if known at compile time Index m_startRow; @@ -681,9 +722,9 @@ struct evaluator_impl<Select<ConditionMatrixType, ThenMatrixType, ElseMatrixType } protected: - typename evaluator<ConditionMatrixType>::type m_conditionImpl; - typename evaluator<ThenMatrixType>::type m_thenImpl; - typename evaluator<ElseMatrixType>::type m_elseImpl; + typename evaluator<ConditionMatrixType>::nestedType m_conditionImpl; 
+ typename evaluator<ThenMatrixType>::nestedType m_thenImpl; + typename evaluator<ElseMatrixType>::nestedType m_elseImpl; }; @@ -731,7 +772,7 @@ struct evaluator_impl<Replicate<ArgType, RowFactor, ColFactor> > } protected: - typename evaluator<ArgType>::type m_argImpl; + typename evaluator<ArgType>::nestedType m_argImpl; Index m_rows; // TODO: Get rid of this if known at compile time Index m_cols; }; @@ -834,7 +875,7 @@ struct evaluator_impl_wrapper_base } protected: - typename evaluator<ArgType>::type m_argImpl; + typename evaluator<ArgType>::nestedType m_argImpl; }; template<typename ArgType> @@ -949,7 +990,7 @@ struct evaluator_impl<Reverse<ArgType, Direction> > } protected: - typename evaluator<ArgType>::type m_argImpl; + typename evaluator<ArgType>::nestedType m_argImpl; Index m_rows; // TODO: Don't use if known at compile time or not needed Index m_cols; }; @@ -993,7 +1034,7 @@ struct evaluator_impl<Diagonal<ArgType, DiagIndex> > } protected: - typename evaluator<ArgType>::type m_argImpl; + typename evaluator<ArgType>::nestedType m_argImpl; Index m_index; // TODO: Don't use if known at compile time private: @@ -1069,7 +1110,7 @@ struct evaluator_impl<SwapWrapper<ArgType> > } protected: - typename evaluator<ArgType>::type m_argImpl; + typename evaluator<ArgType>::nestedType m_argImpl; }; @@ -1133,7 +1174,7 @@ struct evaluator_impl<SelfCwiseBinaryOp<BinaryOp, LhsXpr, RhsXpr> > } protected: - typename evaluator<LhsXpr>::type m_argImpl; + typename evaluator<LhsXpr>::nestedType m_argImpl; const BinaryOp m_functor; }; diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h index aadaa9303..e814a4710 100644 --- a/Eigen/src/Core/ProductEvaluators.h +++ b/Eigen/src/Core/ProductEvaluators.h @@ -50,6 +50,14 @@ struct evaluator_impl<Product<Lhs, Rhs> > { } }; +template<typename XprType, typename ProductType> +struct product_evaluator_traits_dispatcher; + +template<typename Lhs, typename Rhs> +struct evaluator_traits<Product<Lhs, Rhs> > + : 
product_evaluator_traits_dispatcher<Product<Lhs, Rhs>, typename ProductReturnType<Lhs, Rhs>::Type> +{ }; + // Case 1: Evaluate all at once // // We can view the GeneralProduct class as a part of the product evaluator. @@ -57,13 +65,20 @@ struct evaluator_impl<Product<Lhs, Rhs> > // InnerProduct is special because GeneralProduct does not have an evalTo() method in this case. template<typename Lhs, typename Rhs> +struct product_evaluator_traits_dispatcher<Product<Lhs, Rhs>, GeneralProduct<Lhs, Rhs, InnerProduct> > +{ + static const int HasEvalTo = 0; +}; + +template<typename Lhs, typename Rhs> struct product_evaluator_dispatcher<Product<Lhs, Rhs>, GeneralProduct<Lhs, Rhs, InnerProduct> > : public evaluator<typename Product<Lhs, Rhs>::PlainObject>::type { typedef Product<Lhs, Rhs> XprType; typedef typename XprType::PlainObject PlainObject; typedef typename evaluator<PlainObject>::type evaluator_base; - + + // TODO: Computation is too early (?) product_evaluator_dispatcher(const XprType& xpr) : evaluator_base(m_result) { m_result.coeffRef(0,0) = (xpr.lhs().transpose().cwiseProduct(xpr.rhs())).sum(); @@ -77,21 +92,30 @@ protected: // TODO: GeneralProduct should take evaluators, not expression objects. 
template<typename Lhs, typename Rhs, int ProductType> +struct product_evaluator_traits_dispatcher<Product<Lhs, Rhs>, GeneralProduct<Lhs, Rhs, ProductType> > +{ + static const int HasEvalTo = 1; +}; + +template<typename Lhs, typename Rhs, int ProductType> struct product_evaluator_dispatcher<Product<Lhs, Rhs>, GeneralProduct<Lhs, Rhs, ProductType> > - : public evaluator<typename Product<Lhs, Rhs>::PlainObject>::type { typedef Product<Lhs, Rhs> XprType; typedef typename XprType::PlainObject PlainObject; typedef typename evaluator<PlainObject>::type evaluator_base; - product_evaluator_dispatcher(const XprType& xpr) : evaluator_base(m_result) + product_evaluator_dispatcher(const XprType& xpr) : m_xpr(xpr) + { } + + template<typename DstEvaluatorType, typename DstXprType> + void evalTo(DstEvaluatorType /* not used */, DstXprType& dst) { - m_result.resize(xpr.rows(), xpr.cols()); - GeneralProduct<Lhs, Rhs, ProductType>(xpr.lhs(), xpr.rhs()).evalTo(m_result); + dst.resize(m_xpr.rows(), m_xpr.cols()); + GeneralProduct<Lhs, Rhs, ProductType>(m_xpr.lhs(), m_xpr.rhs()).evalTo(dst); } -protected: - PlainObject m_result; +protected: + const XprType& m_xpr; }; // Case 2: Evaluate coeff by coeff @@ -107,6 +131,12 @@ template<int StorageOrder, int UnrollingIndex, typename Lhs, typename Rhs, typen struct etor_product_packet_impl; template<typename Lhs, typename Rhs, typename LhsNested, typename RhsNested, int Flags> +struct product_evaluator_traits_dispatcher<Product<Lhs, Rhs>, CoeffBasedProduct<LhsNested, RhsNested, Flags> > +{ + static const int HasEvalTo = 0; +}; + +template<typename Lhs, typename Rhs, typename LhsNested, typename RhsNested, int Flags> struct product_evaluator_dispatcher<Product<Lhs, Rhs>, CoeffBasedProduct<LhsNested, RhsNested, Flags> > : evaluator_impl_base<Product<Lhs, Rhs> > { diff --git a/test/evaluators.cpp b/test/evaluators.cpp index 62ba5b126..3081d7858 100644 --- a/test/evaluators.cpp +++ b/test/evaluators.cpp @@ -65,6 +65,11 @@ void test_evaluators() 
VERIFY_IS_APPROX_EVALUATOR2(d, s * prod(a,b), s * a*b); VERIFY_IS_APPROX_EVALUATOR2(d, prod(a,b).transpose(), (a*b).transpose()); VERIFY_IS_APPROX_EVALUATOR2(d, prod(a,b) + prod(b,c), a*b + b*c); + + // check that prod works even with aliasing present + c = a*a; + copy_using_evaluator(a, prod(a,a)); + VERIFY_IS_APPROX(a,c); } { |