diff options
author | Gael Guennebaud <g.gael@free.fr> | 2013-12-02 17:54:38 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2013-12-02 17:54:38 +0100 |
commit | f0b82c3ab972a1eafd6aa4f08f4eaffa0d6f1e55 (patch) | |
tree | e7117c7773f19786752b4d702a7e7205dc3d1cf2 /Eigen/src/Core/Redux.h | |
parent | 6f1a0479b3a8176f4d8db6513cbb39c027c34b7f (diff) |
Make reductions compatible with evaluators
Diffstat (limited to 'Eigen/src/Core/Redux.h')
-rw-r--r-- | Eigen/src/Core/Redux.h | 81 |
1 files changed, 74 insertions, 7 deletions
diff --git a/Eigen/src/Core/Redux.h b/Eigen/src/Core/Redux.h index b2c775d90..1b6a4177a 100644 --- a/Eigen/src/Core/Redux.h +++ b/Eigen/src/Core/Redux.h @@ -174,7 +174,7 @@ struct redux_impl<Func, Derived, DefaultTraversal, NoUnrolling> typedef typename Derived::Scalar Scalar; typedef typename Derived::Index Index; EIGEN_DEVICE_FUNC - static EIGEN_STRONG_INLINE Scalar run(const Derived& mat, const Func& func) + static EIGEN_STRONG_INLINE Scalar run(const Derived &mat, const Func& func) { eigen_assert(mat.rows()>0 && mat.cols()>0 && "you are using an empty matrix"); Scalar res; @@ -200,10 +200,10 @@ struct redux_impl<Func, Derived, LinearVectorizedTraversal, NoUnrolling> typedef typename packet_traits<Scalar>::type PacketScalar; typedef typename Derived::Index Index; - static Scalar run(const Derived& mat, const Func& func) + static Scalar run(const Derived &mat, const Func& func) { const Index size = mat.size(); - eigen_assert(size && "you are using an empty matrix"); + const Index packetSize = packet_traits<Scalar>::size; const Index alignedStart = internal::first_aligned(mat); enum { @@ -258,7 +258,7 @@ struct redux_impl<Func, Derived, SliceVectorizedTraversal, NoUnrolling> typedef typename packet_traits<Scalar>::type PacketScalar; typedef typename Derived::Index Index; - static Scalar run(const Derived& mat, const Func& func) + static Scalar run(const Derived &mat, const Func& func) { eigen_assert(mat.rows()>0 && mat.cols()>0 && "you are using an empty matrix"); const Index innerSize = mat.innerSize(); @@ -300,9 +300,8 @@ struct redux_impl<Func, Derived, LinearVectorizedTraversal, CompleteUnrolling> Size = Derived::SizeAtCompileTime, VectorizedSize = (Size / PacketSize) * PacketSize }; - static EIGEN_STRONG_INLINE Scalar run(const Derived& mat, const Func& func) + static EIGEN_STRONG_INLINE Scalar run(const Derived &mat, const Func& func) { - eigen_assert(mat.rows()>0 && mat.cols()>0 && "you are using an empty matrix"); Scalar res = func.predux(redux_vec_unroller<Func, Derived, 0, Size / PacketSize>::run(mat,func)); if (VectorizedSize != Size) res = func(res,redux_novec_unroller<Func, Derived, VectorizedSize, Size-VectorizedSize>::run(mat,func)); @@ -310,6 +309,64 @@ struct redux_impl<Func, Derived, LinearVectorizedTraversal, CompleteUnrolling> } }; +#ifdef EIGEN_ENABLE_EVALUATORS +// evaluator adaptor +template<typename XprType> +class redux_evaluator +{ +public: + redux_evaluator(const XprType &xpr) : m_evaluator(xpr), m_xpr(xpr) {} + + typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketScalar PacketScalar; + typedef typename XprType::PacketReturnType PacketReturnType; + + enum { + MaxRowsAtCompileTime = XprType::MaxRowsAtCompileTime, + MaxColsAtCompileTime = XprType::MaxColsAtCompileTime, + // TODO we should not remove DirectAccessBit and rather find an elegant way to query the alignment offset at runtime from the evaluator + Flags = XprType::Flags & ~DirectAccessBit, + IsRowMajor = XprType::IsRowMajor, + SizeAtCompileTime = XprType::SizeAtCompileTime, + InnerSizeAtCompileTime = XprType::InnerSizeAtCompileTime, + CoeffReadCost = XprType::CoeffReadCost + }; + + Index rows() const { return m_xpr.rows(); } + Index cols() const { return m_xpr.cols(); } + Index size() const { return m_xpr.size(); } + Index innerSize() const { return m_xpr.innerSize(); } + Index outerSize() const { return m_xpr.outerSize(); } + + CoeffReturnType coeff(Index row, Index col) const + { return m_evaluator.coeff(row, col); } + + CoeffReturnType coeff(Index index) const + { return m_evaluator.coeff(index); } + + template<int LoadMode> + PacketReturnType packet(Index row, Index col) const + { return m_evaluator.template packet<LoadMode>(row, col); } + + template<int LoadMode> + PacketReturnType packet(Index index) const + { return m_evaluator.template packet<LoadMode>(index); } + + CoeffReturnType coeffByOuterInner(Index outer, Index inner) const + { return m_evaluator.coeff(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer); } + + template<int LoadMode> + PacketReturnType packetByOuterInner(Index outer, Index inner) const + { return m_evaluator.template packet<LoadMode>(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer); } + +protected: + typename internal::evaluator<XprType>::nestedType m_evaluator; + const XprType &m_xpr; +}; +#endif + } // end namespace internal /*************************************************************************** @@ -320,7 +377,7 @@ struct redux_impl<Func, Derived, LinearVectorizedTraversal, CompleteUnrolling> /** \returns the result of a full redux operation on the whole matrix or vector using \a func * * The template parameter \a BinaryOp is the type of the functor \a func which must be - * an associative operator. Both current STL and TR1 functor styles are handled. + * an associative operator. Both current C++98 and C++11 functor styles are handled. * * \sa DenseBase::sum(), DenseBase::minCoeff(), DenseBase::maxCoeff(), MatrixBase::colwise(), MatrixBase::rowwise() */ @@ -329,9 +386,19 @@ template<typename Func> EIGEN_STRONG_INLINE typename internal::result_of<Func(typename internal::traits<Derived>::Scalar)>::type DenseBase<Derived>::redux(const Func& func) const { + eigen_assert(this->rows()>0 && this->cols()>0 && "you are using an empty matrix"); +#ifdef EIGEN_TEST_EVALUATORS + + typedef typename internal::redux_evaluator<Derived> ThisEvaluator; + ThisEvaluator thisEval(derived()); + return internal::redux_impl<Func, ThisEvaluator>::run(thisEval, func); + +#else typedef typename internal::remove_all<typename Derived::Nested>::type ThisNested; + return internal::redux_impl<Func, ThisNested> ::run(derived(), func); +#endif } /** \returns the minimum of all coefficients of \c *this. |