diff options
Diffstat (limited to 'Eigen/src/Core/CoreEvaluators.h')
-rw-r--r-- | Eigen/src/Core/CoreEvaluators.h | 298 |
1 files changed, 259 insertions, 39 deletions
diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h index 932178f53..00c079bd8 100644 --- a/Eigen/src/Core/CoreEvaluators.h +++ b/Eigen/src/Core/CoreEvaluators.h @@ -41,11 +41,20 @@ template<> struct storage_kind_to_shape<TranspositionsStorage> { typedef Transp // We currently distinguish the following kind of evaluators: // - unary_evaluator for expressions taking only one arguments (CwiseUnaryOp, CwiseUnaryView, Transpose, MatrixWrapper, ArrayWrapper, Reverse, Replicate) // - binary_evaluator for expression taking two arguments (CwiseBinaryOp) +// - ternary_evaluator for expression taking three arguments (CwiseTernaryOp) // - product_evaluator for linear algebra products (Product); special case of binary_evaluator because it requires additional tags for dispatching. // - mapbase_evaluator for Map, Block, Ref // - block_evaluator for Block (special dispatching to a mapbase_evaluator or unary_evaluator) template< typename T, + typename Arg1Kind = typename evaluator_traits<typename T::Arg1>::Kind, + typename Arg2Kind = typename evaluator_traits<typename T::Arg2>::Kind, + typename Arg3Kind = typename evaluator_traits<typename T::Arg3>::Kind, + typename Arg1Scalar = typename traits<typename T::Arg1>::Scalar, + typename Arg2Scalar = typename traits<typename T::Arg2>::Scalar, + typename Arg3Scalar = typename traits<typename T::Arg3>::Scalar> struct ternary_evaluator; + +template< typename T, typename LhsKind = typename evaluator_traits<typename T::Lhs>::Kind, typename RhsKind = typename evaluator_traits<typename T::Rhs>::Kind, typename LhsScalar = typename traits<typename T::Lhs>::Scalar, @@ -328,6 +337,120 @@ protected: // Like Matrix and Array, this is not really a unary expression, so we directly specialize evaluator. // Likewise, there is not need to more sophisticated dispatching here. +template<typename Scalar,typename NullaryOp, + bool has_nullary = has_nullary_operator<NullaryOp>::value, + bool has_unary = has_unary_operator<NullaryOp>::value, + bool has_binary = has_binary_operator<NullaryOp>::value> +struct nullary_wrapper +{ + template <typename IndexType> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType i, IndexType j) const { return op(i,j); } + template <typename IndexType> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType i) const { return op(i); } + + template <typename T, typename IndexType> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType i, IndexType j) const { return op.template packetOp<T>(i,j); } + template <typename T, typename IndexType> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType i) const { return op.template packetOp<T>(i); } +}; + +template<typename Scalar,typename NullaryOp> +struct nullary_wrapper<Scalar,NullaryOp,true,false,false> +{ + template <typename IndexType> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType=0, IndexType=0) const { return op(); } + template <typename T, typename IndexType> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType=0, IndexType=0) const { return op.template packetOp<T>(); } +}; + +template<typename Scalar,typename NullaryOp> +struct nullary_wrapper<Scalar,NullaryOp,false,false,true> +{ + template <typename IndexType> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType i, IndexType j=0) const { return op(i,j); } + template <typename T, typename IndexType> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType i, IndexType j=0) const { return op.template packetOp<T>(i,j); } +}; + +// We need the following specialization for vector-only functors assigned to a runtime vector, +// for instance, using linspace and assigning a RowVectorXd to a MatrixXd or even a row of a MatrixXd. +// In this case, i==0 and j is used for the actual iteration. +template<typename Scalar,typename NullaryOp> +struct nullary_wrapper<Scalar,NullaryOp,false,true,false> +{ + template <typename IndexType> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType i, IndexType j) const { + eigen_assert(i==0 || j==0); + return op(i+j); + } + template <typename T, typename IndexType> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType i, IndexType j) const { + eigen_assert(i==0 || j==0); + return op.template packetOp<T>(i+j); + } + + template <typename IndexType> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType i) const { return op(i); } + template <typename T, typename IndexType> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType i) const { return op.template packetOp<T>(i); } +}; + +template<typename Scalar,typename NullaryOp> +struct nullary_wrapper<Scalar,NullaryOp,false,false,false> {}; + +#if 0 && EIGEN_COMP_MSVC>0 +// Disable this ugly workaround. This is now handled in traits<Ref>::match, +// but this piece of code might still become handly if some other weird compilation +// erros pop up again. + +// MSVC exhibits a weird compilation error when +// compiling: +// Eigen::MatrixXf A = MatrixXf::Random(3,3); +// Ref<const MatrixXf> R = 2.f*A; +// and that has_*ary_operator<scalar_constant_op<float>> have not been instantiated yet. +// The "problem" is that evaluator<2.f*A> is instantiated by traits<Ref>::match<2.f*A> +// and at that time has_*ary_operator<T> returns true regardless of T. +// Then nullary_wrapper is badly instantiated as nullary_wrapper<.,.,true,true,true>. +// The trick is thus to defer the proper instantiation of nullary_wrapper when coeff(), +// and packet() are really instantiated as implemented below: + +// This is a simple wrapper around Index to enforce the re-instantiation of +// has_*ary_operator when needed. +template<typename T> struct nullary_wrapper_workaround_msvc { + nullary_wrapper_workaround_msvc(const T&); + operator T()const; +}; + +template<typename Scalar,typename NullaryOp> +struct nullary_wrapper<Scalar,NullaryOp,true,true,true> +{ + template <typename IndexType> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType i, IndexType j) const { + return nullary_wrapper<Scalar,NullaryOp, + has_nullary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value, + has_unary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value, + has_binary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value>().operator()(op,i,j); + } + template <typename IndexType> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, IndexType i) const { + return nullary_wrapper<Scalar,NullaryOp, + has_nullary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value, + has_unary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value, + has_binary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value>().operator()(op,i); + } + + template <typename T, typename IndexType> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType i, IndexType j) const { + return nullary_wrapper<Scalar,NullaryOp, + has_nullary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value, + has_unary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value, + has_binary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value>().template packetOp<T>(op,i,j); + } + template <typename T, typename IndexType> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, IndexType i) const { + return nullary_wrapper<Scalar,NullaryOp, + has_nullary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value, + has_unary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value, + has_binary_operator<NullaryOp,nullary_wrapper_workaround_msvc<IndexType> >::value>().template packetOp<T>(op,i); + } +}; +#endif // MSVC workaround + template<typename NullaryOp, typename PlainObjectType> struct evaluator<CwiseNullaryOp<NullaryOp,PlainObjectType> > : evaluator_base<CwiseNullaryOp<NullaryOp,PlainObjectType> > @@ -347,41 +470,44 @@ struct evaluator<CwiseNullaryOp<NullaryOp,PlainObjectType> > }; EIGEN_DEVICE_FUNC explicit evaluator(const XprType& n) - : m_functor(n.functor()) + : m_functor(n.functor()), m_wrapper() { EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); } typedef typename XprType::CoeffReturnType CoeffReturnType; + template <typename IndexType> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - CoeffReturnType coeff(Index row, Index col) const + CoeffReturnType coeff(IndexType row, IndexType col) const { - return m_functor(row, col); + return m_wrapper(m_functor, row, col); } + template <typename IndexType> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - CoeffReturnType coeff(Index index) const + CoeffReturnType coeff(IndexType index) const { - return m_functor(index); + return m_wrapper(m_functor,index); } - template<int LoadMode, typename PacketType> + template<int LoadMode, typename PacketType, typename IndexType> EIGEN_STRONG_INLINE - PacketType packet(Index row, Index col) const + PacketType packet(IndexType row, IndexType col) const { - return m_functor.template packetOp<Index,PacketType>(row, col); + return m_wrapper.template packetOp<PacketType>(m_functor, row, col); } - template<int LoadMode, typename PacketType> + template<int LoadMode, typename PacketType, typename IndexType> EIGEN_STRONG_INLINE - PacketType packet(Index index) const + PacketType packet(IndexType index) const { - return m_functor.template packetOp<Index,PacketType>(index); + return m_wrapper.template packetOp<PacketType>(m_functor, index); } protected: const NullaryOp m_functor; + const internal::nullary_wrapper<CoeffReturnType,NullaryOp> m_wrapper; }; // -------------------- CwiseUnaryOp -------------------- @@ -442,6 +568,96 @@ protected: evaluator<ArgType> m_argImpl; }; +// -------------------- CwiseTernaryOp -------------------- + +// this is a ternary expression +template<typename TernaryOp, typename Arg1, typename Arg2, typename Arg3> +struct evaluator<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> > + : public ternary_evaluator<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> > +{ + typedef CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> XprType; + typedef ternary_evaluator<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> > Base; + + EIGEN_DEVICE_FUNC explicit evaluator(const XprType& xpr) : Base(xpr) {} +}; + +template<typename TernaryOp, typename Arg1, typename Arg2, typename Arg3> +struct ternary_evaluator<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3>, IndexBased, IndexBased> + : evaluator_base<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> > +{ + typedef CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> XprType; + + enum { + CoeffReadCost = evaluator<Arg1>::CoeffReadCost + evaluator<Arg2>::CoeffReadCost + evaluator<Arg3>::CoeffReadCost + functor_traits<TernaryOp>::Cost, + + Arg1Flags = evaluator<Arg1>::Flags, + Arg2Flags = evaluator<Arg2>::Flags, + Arg3Flags = evaluator<Arg3>::Flags, + SameType = is_same<typename Arg1::Scalar,typename Arg2::Scalar>::value && is_same<typename Arg1::Scalar,typename Arg3::Scalar>::value, + StorageOrdersAgree = (int(Arg1Flags)&RowMajorBit)==(int(Arg2Flags)&RowMajorBit) && (int(Arg1Flags)&RowMajorBit)==(int(Arg3Flags)&RowMajorBit), + Flags0 = (int(Arg1Flags) | int(Arg2Flags) | int(Arg3Flags)) & ( + HereditaryBits + | (int(Arg1Flags) & int(Arg2Flags) & int(Arg3Flags) & + ( (StorageOrdersAgree ? LinearAccessBit : 0) + | (functor_traits<TernaryOp>::PacketAccess && StorageOrdersAgree && SameType ? PacketAccessBit : 0) + ) + ) + ), + Flags = (Flags0 & ~RowMajorBit) | (Arg1Flags & RowMajorBit), + Alignment = EIGEN_PLAIN_ENUM_MIN( + EIGEN_PLAIN_ENUM_MIN(evaluator<Arg1>::Alignment, evaluator<Arg2>::Alignment), + evaluator<Arg3>::Alignment) + }; + + EIGEN_DEVICE_FUNC explicit ternary_evaluator(const XprType& xpr) + : m_functor(xpr.functor()), + m_arg1Impl(xpr.arg1()), + m_arg2Impl(xpr.arg2()), + m_arg3Impl(xpr.arg3()) + { + EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<TernaryOp>::Cost); + EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); + } + + typedef typename XprType::CoeffReturnType CoeffReturnType; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index row, Index col) const + { + return m_functor(m_arg1Impl.coeff(row, col), m_arg2Impl.coeff(row, col), m_arg3Impl.coeff(row, col)); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + CoeffReturnType coeff(Index index) const + { + return m_functor(m_arg1Impl.coeff(index), m_arg2Impl.coeff(index), m_arg3Impl.coeff(index)); + } + + template<int LoadMode, typename PacketType> + EIGEN_STRONG_INLINE + PacketType packet(Index row, Index col) const + { + return m_functor.packetOp(m_arg1Impl.template packet<LoadMode,PacketType>(row, col), + m_arg2Impl.template packet<LoadMode,PacketType>(row, col), + m_arg3Impl.template packet<LoadMode,PacketType>(row, col)); + } + + template<int LoadMode, typename PacketType> + EIGEN_STRONG_INLINE + PacketType packet(Index index) const + { + return m_functor.packetOp(m_arg1Impl.template packet<LoadMode,PacketType>(index), + m_arg2Impl.template packet<LoadMode,PacketType>(index), + m_arg3Impl.template packet<LoadMode,PacketType>(index)); + } + +protected: + const TernaryOp m_functor; + evaluator<Arg1> m_arg1Impl; + evaluator<Arg2> m_arg2Impl; + evaluator<Arg3> m_arg3Impl; +}; + // -------------------- CwiseBinaryOp -------------------- // this is a binary expression @@ -601,73 +817,79 @@ struct mapbase_evaluator : evaluator_base<Derived> ColsAtCompileTime = XprType::ColsAtCompileTime, CoeffReadCost = NumTraits<Scalar>::ReadCost }; - + EIGEN_DEVICE_FUNC explicit mapbase_evaluator(const XprType& map) - : m_data(const_cast<PointerType>(map.data())), - m_xpr(map) + : m_data(const_cast<PointerType>(map.data())), + m_innerStride(map.innerStride()), + m_outerStride(map.outerStride()) { EIGEN_STATIC_ASSERT(EIGEN_IMPLIES(evaluator<Derived>::Flags&PacketAccessBit, internal::inner_stride_at_compile_time<Derived>::ret==1), PACKET_ACCESS_REQUIRES_TO_HAVE_INNER_STRIDE_FIXED_TO_1); EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); } - + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index row, Index col) const { - return m_data[col * m_xpr.colStride() + row * m_xpr.rowStride()]; + return m_data[col * colStride() + row * rowStride()]; } - + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { - return m_data[index * m_xpr.innerStride()]; + return m_data[index * m_innerStride.value()]; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index row, Index col) { - return m_data[col * m_xpr.colStride() + row * m_xpr.rowStride()]; + return m_data[col * colStride() + row * rowStride()]; } - + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { - return m_data[index * m_xpr.innerStride()]; + return m_data[index * m_innerStride.value()]; } - + template<int LoadMode, typename PacketType> EIGEN_STRONG_INLINE - PacketType packet(Index row, Index col) const + PacketType packet(Index row, Index col) const { - PointerType ptr = m_data + row * m_xpr.rowStride() + col * m_xpr.colStride(); + PointerType ptr = m_data + row * rowStride() + col * colStride(); return internal::ploadt<PacketType, LoadMode>(ptr); } template<int LoadMode, typename PacketType> EIGEN_STRONG_INLINE - PacketType packet(Index index) const + PacketType packet(Index index) const { - return internal::ploadt<PacketType, LoadMode>(m_data + index * m_xpr.innerStride()); + return internal::ploadt<PacketType, LoadMode>(m_data + index * m_innerStride.value()); } - + template<int StoreMode, typename PacketType> EIGEN_STRONG_INLINE - void writePacket(Index row, Index col, const PacketType& x) + void writePacket(Index row, Index col, const PacketType& x) { - PointerType ptr = m_data + row * m_xpr.rowStride() + col * m_xpr.colStride(); + PointerType ptr = m_data + row * rowStride() + col * colStride(); return internal::pstoret<Scalar, PacketType, StoreMode>(ptr, x); } - + template<int StoreMode, typename PacketType> EIGEN_STRONG_INLINE - void writePacket(Index index, const PacketType& x) + void writePacket(Index index, const PacketType& x) { - internal::pstoret<Scalar, PacketType, StoreMode>(m_data + index * m_xpr.innerStride(), x); + internal::pstoret<Scalar, PacketType, StoreMode>(m_data + index * m_innerStride.value(), x); } - protected: + EIGEN_DEVICE_FUNC + inline Index rowStride() const { return XprType::IsRowMajor ? m_outerStride.value() : m_innerStride.value(); } + EIGEN_DEVICE_FUNC + inline Index colStride() const { return XprType::IsRowMajor ? m_innerStride.value() : m_outerStride.value(); } + PointerType m_data; - const XprType& m_xpr; + const internal::variable_if_dynamic<Index, XprType::InnerStrideAtCompileTime> m_innerStride; + const internal::variable_if_dynamic<Index, XprType::OuterStrideAtCompileTime> m_outerStride; }; template<typename PlainObjectType, int MapOptions, typename StrideType> @@ -755,9 +977,7 @@ struct evaluator<Block<ArgType, BlockRows, BlockCols, InnerPanel> > OuterStrideAtCompileTime = HasSameStorageOrderAsArgType ? int(outer_stride_at_compile_time<ArgType>::ret) : int(inner_stride_at_compile_time<ArgType>::ret), - MaskPacketAccessBit = (InnerSize == Dynamic || (InnerSize % packet_traits<Scalar>::size) == 0) - && (InnerStrideAtCompileTime == 1) - ? PacketAccessBit : 0, + MaskPacketAccessBit = (InnerStrideAtCompileTime == 1) ? PacketAccessBit : 0, FlagsLinearAccessBit = (RowsAtCompileTime == 1 || ColsAtCompileTime == 1 || (InnerPanel && (evaluator<ArgType>::Flags&LinearAccessBit))) ? LinearAccessBit : 0, FlagsRowMajorBit = XprType::Flags&RowMajorBit, @@ -884,7 +1104,7 @@ struct block_evaluator<ArgType, BlockRows, BlockCols, InnerPanel, /* HasDirectAc : mapbase_evaluator<XprType, typename XprType::PlainObject>(block) { // TODO: for the 3.3 release, this should be turned to an internal assertion, but let's keep it as is for the beta lifetime - eigen_assert(((size_t(block.data()) % EIGEN_PLAIN_ENUM_MAX(1,evaluator<XprType>::Alignment)) == 0) && "data is not aligned"); + eigen_assert(((internal::UIntPtr(block.data()) % EIGEN_PLAIN_ENUM_MAX(1,evaluator<XprType>::Alignment)) == 0) && "data is not aligned"); } }; @@ -1325,7 +1545,7 @@ struct evaluator<Diagonal<ArgType, DiagIndex> > enum { CoeffReadCost = evaluator<ArgType>::CoeffReadCost, - Flags = (unsigned int)evaluator<ArgType>::Flags & (HereditaryBits | LinearAccessBit | DirectAccessBit) & ~RowMajorBit, + Flags = (unsigned int)(evaluator<ArgType>::Flags & (HereditaryBits | DirectAccessBit) & ~RowMajorBit) | LinearAccessBit, Alignment = 0 }; |