diff options
Diffstat (limited to 'Eigen/src/Core/CoreEvaluators.h')
-rw-r--r-- | Eigen/src/Core/CoreEvaluators.h | 61 |
1 files changed, 56 insertions, 5 deletions
diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h index 7ba92963c..01678c6c2 100644 --- a/Eigen/src/Core/CoreEvaluators.h +++ b/Eigen/src/Core/CoreEvaluators.h @@ -337,6 +337,56 @@ protected: // Like Matrix and Array, this is not really a unary expression, so we directly specialize evaluator. // Likewise, there is not need to more sophisticated dispatching here. +template<typename Scalar,typename NullaryOp, + bool has_nullary = has_nullary_operator<NullaryOp>::value, + bool has_unary = has_unary_operator<NullaryOp>::value, + bool has_binary = has_binary_operator<NullaryOp>::value> +struct nullary_wrapper +{ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, Index i, Index j) const { return op(i,j); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, Index i) const { return op(i); } + + template <typename T> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, Index i, Index j) const { return op.template packetOp<T>(i,j); } + template <typename T> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, Index i) const { return op.template packetOp<T>(i); } +}; + +template<typename Scalar,typename NullaryOp> +struct nullary_wrapper<Scalar,NullaryOp,true,false,false> +{ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, ...) const { return op(); } + template <typename T> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, ...) const { return op.template packetOp<T>(); } +}; + +template<typename Scalar,typename NullaryOp> +struct nullary_wrapper<Scalar,NullaryOp,false,false,true> +{ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, Index i, Index j=0) const { return op(i,j); } + template <typename T> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, Index i, Index j=0) const { return op.template packetOp<T>(i,j); } +}; + +// We need the following specialization for vector-only functors assigned to a runtime vector, +// for instance, using linspace and assigning a RowVectorXd to a MatrixXd or even a row of a MatrixXd. +// In this case, i==0 and j is used for the actual iteration. +template<typename Scalar,typename NullaryOp> +struct nullary_wrapper<Scalar,NullaryOp,false,true,false> + : nullary_wrapper<Scalar,NullaryOp,false,true,true> // to get the identity wrapper +{ + typedef nullary_wrapper<Scalar,NullaryOp,false,true,true> base; + using base::operator(); + using base::packetOp; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, Index i, Index j) const { + eigen_assert(i==0 || j==0); + return op(i+j); + } + template <typename T> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, Index i, Index j) const { + eigen_assert(i==0 || j==0); + return op.template packetOp<T>(i+j); + } +}; + +template<typename Scalar,typename NullaryOp> +struct nullary_wrapper<Scalar,NullaryOp,false,false,false> {}; + template<typename NullaryOp, typename PlainObjectType> struct evaluator<CwiseNullaryOp<NullaryOp,PlainObjectType> > : evaluator_base<CwiseNullaryOp<NullaryOp,PlainObjectType> > @@ -356,7 +406,7 @@ struct evaluator<CwiseNullaryOp<NullaryOp,PlainObjectType> > }; EIGEN_DEVICE_FUNC explicit evaluator(const XprType& n) - : m_functor(n.functor()) + : m_functor(n.functor()), m_wrapper() { EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); } @@ -366,31 +416,32 @@ struct evaluator<CwiseNullaryOp<NullaryOp,PlainObjectType> > EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index row, Index col) const { - return m_functor(row, col); + return m_wrapper(m_functor, row, col); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { - return m_functor(index); + return m_wrapper(m_functor,index); } template<int LoadMode, typename PacketType> EIGEN_STRONG_INLINE PacketType packet(Index row, Index col) const { - return m_functor.template packetOp<Index,PacketType>(row, col); + return m_wrapper.template packetOp<PacketType>(m_functor,row, col); } template<int LoadMode, typename PacketType> EIGEN_STRONG_INLINE PacketType packet(Index index) const { - return m_functor.template packetOp<Index,PacketType>(index); + return m_wrapper.template packetOp<PacketType>(m_functor,index); } protected: const NullaryOp m_functor; + const internal::nullary_wrapper<CoeffReturnType,NullaryOp> m_wrapper; }; // -------------------- CwiseUnaryOp -------------------- |