diff options
Diffstat (limited to 'Eigen/src/Core/CoreEvaluators.h')
-rw-r--r-- | Eigen/src/Core/CoreEvaluators.h | 139 |
1 files changed, 128 insertions, 11 deletions
diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h index 7ba92963c..7a5540593 100644 --- a/Eigen/src/Core/CoreEvaluators.h +++ b/Eigen/src/Core/CoreEvaluators.h @@ -337,6 +337,120 @@ protected: // Like Matrix and Array, this is not really a unary expression, so we directly specialize evaluator. // Likewise, there is not need to more sophisticated dispatching here. +template<typename Scalar,typename NullaryOp, + bool has_nullary = has_nullary_operator<NullaryOp>::value, + bool has_unary = has_unary_operator<NullaryOp>::value, + bool has_binary = has_binary_operator<NullaryOp>::value> +struct nullary_wrapper +{ + template <typename IndexType> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType i, IndexType j) const { return op(i,j); } + template <typename IndexType> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType i) const { return op(i); } + + template <typename T, typename IndexType> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType i, IndexType j) const { return op.template packetOp<T>(i,j); } + template <typename T, typename IndexType> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType i) const { return op.template packetOp<T>(i); } +}; + +template<typename Scalar,typename NullaryOp> +struct nullary_wrapper<Scalar,NullaryOp,true,false,false> +{ + template <typename IndexType> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType=0, IndexType=0) const { return op(); } + template <typename T, typename IndexType> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType=0, IndexType=0) const { return op.template packetOp<T>(); } +}; + +template<typename Scalar,typename NullaryOp> +struct nullary_wrapper<Scalar,NullaryOp,false,false,true> +{ + template <typename IndexType> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType i, IndexType j=0) const { return op(i,j); } + template <typename T, typename IndexType> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType i, IndexType j=0) const { return op.template packetOp<T>(i,j); } +}; + +// We need the following specialization for vector-only functors assigned to a runtime vector, +// for instance, using linspace and assigning a RowVectorXd to a MatrixXd or even a row of a MatrixXd. +// In this case, i==0 and j is used for the actual iteration. +template<typename Scalar,typename NullaryOp> +struct nullary_wrapper<Scalar,NullaryOp,false,true,false> +{ + template <typename IndexType> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType i, IndexType j) const { + eigen_assert(i==0 || j==0); + return op(i+j); + } + template <typename T, typename IndexType> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType i, IndexType j) const { + eigen_assert(i==0 || j==0); + return op.template packetOp<T>(i+j); + } + + template <typename IndexType> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType i) const { return op(i); } + template <typename T, typename IndexType> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType i) const { return op.template packetOp<T>(i); } +}; + +template<typename Scalar,typename NullaryOp> +struct nullary_wrapper<Scalar,NullaryOp,false,false,false> {}; + +#if 0 && EIGEN_COMP_MSVC>0 +// Disable this ugly workaround. This is now handled in traits<Ref>::match, +// but this piece of code might still become handly if some other weird compilation +// erros pop up again. + +// MSVC exhibits a weird compilation error when +// compiling: +// Eigen::MatrixXf A = MatrixXf::Random(3,3); +// Ref<const MatrixXf> R = 2.f*A; +// and that has_*ary_operator<scalar_constant_op<float>> have not been instantiated yet. +// The "problem" is that evaluator<2.f*A> is instantiated by traits<Ref>::match<2.f*A> +// and at that time has_*ary_operator<T> returns true regardless of T. +// Then nullary_wrapper is badly instantiated as nullary_wrapper<.,.,true,true,true>. +// The trick is thus to defer the proper instantiation of nullary_wrapper when coeff(), +// and packet() are really instantiated as implemented below: + +// This is a simple wrapper around Index to enforce the re-instantiation of +// has_*ary_operator when needed. +template<typename T> struct nullary_wrapper_workaround_msvc { + nullary_wrapper_workaround_msvc(const T&); + operator T()const; +}; + +template<typename Scalar,typename NullaryOp> +struct nullary_wrapper<Scalar,NullaryOp,true,true,true> +{ + template <typename IndexType> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType i, IndexType j) const { + return nullary_wrapper<Scalar,NullaryOp, + has_nullary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value, + has_unary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value, + has_binary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value>().operator()(op,i,j); + } + template <typename IndexType> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType i) const { + return nullary_wrapper<Scalar,NullaryOp, + has_nullary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value, + has_unary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value, + has_binary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value>().operator()(op,i); + } + + template <typename T, typename IndexType> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType i, IndexType j) const { + return nullary_wrapper<Scalar,NullaryOp, + has_nullary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value, + has_unary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value, + has_binary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value>().template packetOp<T>(op,i,j); + } + template <typename T, typename IndexType> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType i) const { + return nullary_wrapper<Scalar,NullaryOp, + has_nullary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value, + has_unary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value, + has_binary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value>().template packetOp<T>(op,i); + } +}; +#endif // MSVC workaround + template<typename NullaryOp, typename PlainObjectType> struct evaluator<CwiseNullaryOp<NullaryOp,PlainObjectType> > : evaluator_base<CwiseNullaryOp<NullaryOp,PlainObjectType> > @@ -356,41 +470,44 @@ struct evaluator<CwiseNullaryOp<NullaryOp,PlainObjectType> > }; EIGEN_DEVICE_FUNC explicit evaluator(const XprType& n) - : m_functor(n.functor()) + : m_functor(n.functor()), m_wrapper() { EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); } typedef typename XprType::CoeffReturnType CoeffReturnType; + template <typename IndexType> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - CoeffReturnType coeff(Index row, Index col) const + CoeffReturnType coeff(IndexType row, IndexType col) const { - return m_functor(row, col); + return m_wrapper(m_functor, row, col); } + template <typename IndexType> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - CoeffReturnType coeff(Index index) const + CoeffReturnType coeff(IndexType index) const { - return m_functor(index); + return m_wrapper(m_functor,index); } - template<int LoadMode, typename PacketType> + template<int LoadMode, typename PacketType, typename IndexType> EIGEN_STRONG_INLINE - PacketType packet(Index row, Index col) const + PacketType packet(IndexType row, IndexType col) const { - return m_functor.template packetOp<Index,PacketType>(row, col); + return m_wrapper.template packetOp<PacketType>(m_functor, row, col); } - template<int LoadMode, typename PacketType> + template<int LoadMode, typename PacketType, typename IndexType> EIGEN_STRONG_INLINE - PacketType packet(Index index) const + PacketType packet(IndexType index) const { - return m_functor.template packetOp<Index,PacketType>(index); + return m_wrapper.template packetOp<PacketType>(m_functor, index); } protected: const NullaryOp m_functor; + const internal::nullary_wrapper<CoeffReturnType,NullaryOp> m_wrapper; }; // -------------------- CwiseUnaryOp -------------------- |