diff options
-rw-r--r-- | Eigen/src/Core/CoreEvaluators.h | 193 |
1 files changed, 121 insertions, 72 deletions
diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h index 0f05ea76e..24fc7835b 100644 --- a/Eigen/src/Core/CoreEvaluators.h +++ b/Eigen/src/Core/CoreEvaluators.h @@ -131,6 +131,27 @@ private: // Here we directly specialize evaluator. This is not really a unary expression, and it is, by definition, dense, // so no need for more sophisticated dispatching. +// this helper permits to completely eliminate m_outerStride if it is known at compiletime. +template<typename Scalar,int OuterStride> class plainobjectbase_evaluator_data { +public: + plainobjectbase_evaluator_data(const Scalar* ptr, Index outerStride) : data(ptr) + { + EIGEN_ONLY_USED_FOR_DEBUG(outerStride); + eigen_internal_assert(outerStride==OuterStride); + } + Index outerStride() const { return OuterStride; } + const Scalar *data; +}; + +template<typename Scalar> class plainobjectbase_evaluator_data<Scalar,Dynamic> { +public: + plainobjectbase_evaluator_data(const Scalar* ptr, Index outerStride) : data(ptr), m_outerStride(outerStride) {} + Index outerStride() const { return m_outerStride; } + const Scalar *data; +protected: + Index m_outerStride; +}; + template<typename Derived> struct evaluator<PlainObjectBase<Derived> > : evaluator_base<Derived> @@ -149,18 +170,21 @@ struct evaluator<PlainObjectBase<Derived> > Flags = traits<Derived>::EvaluatorFlags, Alignment = traits<Derived>::Alignment }; - + enum { + // We do not need to know the outer stride for vectors + OuterStrideAtCompileTime = IsVectorAtCompileTime ? 0 + : int(IsRowMajor) ? ColsAtCompileTime + : RowsAtCompileTime + }; + EIGEN_DEVICE_FUNC evaluator() - : m_data(0), - m_outerStride(IsVectorAtCompileTime ? 0 - : int(IsRowMajor) ? ColsAtCompileTime - : RowsAtCompileTime) + : m_d(0,OuterStrideAtCompileTime) { EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); } - + EIGEN_DEVICE_FUNC explicit evaluator(const PlainObjectType& m) - : m_data(m.data()), m_outerStride(IsVectorAtCompileTime ? 0 : m.outerStride()) + : m_d(m.data(),IsVectorAtCompileTime ? 0 : m.outerStride()) { EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); } @@ -169,30 +193,30 @@ struct evaluator<PlainObjectBase<Derived> > CoeffReturnType coeff(Index row, Index col) const { if (IsRowMajor) - return m_data[row * m_outerStride.value() + col]; + return m_d.data[row * m_d.outerStride() + col]; else - return m_data[row + col * m_outerStride.value()]; + return m_d.data[row + col * m_d.outerStride()]; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { - return m_data[index]; + return m_d.data[index]; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index row, Index col) { if (IsRowMajor) - return const_cast<Scalar*>(m_data)[row * m_outerStride.value() + col]; + return const_cast<Scalar*>(m_d.data)[row * m_d.outerStride() + col]; else - return const_cast<Scalar*>(m_data)[row + col * m_outerStride.value()]; + return const_cast<Scalar*>(m_d.data)[row + col * m_d.outerStride()]; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { - return const_cast<Scalar*>(m_data)[index]; + return const_cast<Scalar*>(m_d.data)[index]; } template<int LoadMode, typename PacketType> @@ -200,16 +224,16 @@ struct evaluator<PlainObjectBase<Derived> > PacketType packet(Index row, Index col) const { if (IsRowMajor) - return ploadt<PacketType, LoadMode>(m_data + row * m_outerStride.value() + col); + return ploadt<PacketType, LoadMode>(m_d.data + row * m_d.outerStride() + col); else - return ploadt<PacketType, LoadMode>(m_data + row + col * m_outerStride.value()); + return ploadt<PacketType, LoadMode>(m_d.data + row + col * m_d.outerStride()); } template<int LoadMode, typename PacketType> EIGEN_STRONG_INLINE PacketType packet(Index index) const { - return ploadt<PacketType, LoadMode>(m_data + index); + return ploadt<PacketType, LoadMode>(m_d.data + index); } template<int StoreMode,typename PacketType> @@ -218,26 +242,22 @@ struct evaluator<PlainObjectBase<Derived> > { if (IsRowMajor) return pstoret<Scalar, PacketType, StoreMode> - (const_cast<Scalar*>(m_data) + row * m_outerStride.value() + col, x); + (const_cast<Scalar*>(m_d.data) + row * m_d.outerStride() + col, x); else return pstoret<Scalar, PacketType, StoreMode> - (const_cast<Scalar*>(m_data) + row + col * m_outerStride.value(), x); + (const_cast<Scalar*>(m_d.data) + row + col * m_d.outerStride(), x); } template<int StoreMode, typename PacketType> EIGEN_STRONG_INLINE void writePacket(Index index, const PacketType& x) { - return pstoret<Scalar, PacketType, StoreMode>(const_cast<Scalar*>(m_data) + index, x); + return pstoret<Scalar, PacketType, StoreMode>(const_cast<Scalar*>(m_d.data) + index, x); } protected: - const Scalar *m_data; - // We do not need to know the outer stride for vectors - variable_if_dynamic<Index, IsVectorAtCompileTime ? 0 - : int(IsRowMajor) ? ColsAtCompileTime - : RowsAtCompileTime> m_outerStride; + plainobjectbase_evaluator_data<Scalar,OuterStrideAtCompileTime> m_d; }; template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols> @@ -535,9 +555,7 @@ struct unary_evaluator<CwiseUnaryOp<UnaryOp, ArgType>, IndexBased > }; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - explicit unary_evaluator(const XprType& op) - : m_functor(op.functor()), - m_argImpl(op.nestedExpression()) + explicit unary_evaluator(const XprType& op) : m_d(op) { EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<UnaryOp>::Cost); EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); @@ -548,32 +566,43 @@ struct unary_evaluator<CwiseUnaryOp<UnaryOp, ArgType>, IndexBased > EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index row, Index col) const { - return m_functor(m_argImpl.coeff(row, col)); + return m_d.func()(m_d.argImpl.coeff(row, col)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { - return m_functor(m_argImpl.coeff(index)); + return m_d.func()(m_d.argImpl.coeff(index)); } template<int LoadMode, typename PacketType> EIGEN_STRONG_INLINE PacketType packet(Index row, Index col) const { - return m_functor.packetOp(m_argImpl.template packet<LoadMode, PacketType>(row, col)); + return m_d.func().packetOp(m_d.argImpl.template packet<LoadMode, PacketType>(row, col)); } template<int LoadMode, typename PacketType> EIGEN_STRONG_INLINE PacketType packet(Index index) const { - return m_functor.packetOp(m_argImpl.template packet<LoadMode, PacketType>(index)); + return m_d.func().packetOp(m_d.argImpl.template packet<LoadMode, PacketType>(index)); } protected: - const UnaryOp m_functor; - evaluator<ArgType> m_argImpl; + + // this helper permits to completely eliminate the functor if it is empty + class Data : private UnaryOp + { + public: + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Data(const XprType& xpr) : UnaryOp(xpr.functor()), argImpl(xpr.nestedExpression()) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const UnaryOp& func() const { return static_cast<const UnaryOp&>(*this); } + evaluator<ArgType> argImpl; + }; + + Data m_d; }; // -------------------- CwiseTernaryOp -------------------- @@ -617,11 +646,7 @@ struct ternary_evaluator<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3>, IndexBased evaluator<Arg3>::Alignment) }; - EIGEN_DEVICE_FUNC explicit ternary_evaluator(const XprType& xpr) - : m_functor(xpr.functor()), - m_arg1Impl(xpr.arg1()), - m_arg2Impl(xpr.arg2()), - m_arg3Impl(xpr.arg3()) + EIGEN_DEVICE_FUNC explicit ternary_evaluator(const XprType& xpr) : m_d(xpr) { EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<TernaryOp>::Cost); EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); @@ -632,38 +657,47 @@ struct ternary_evaluator<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3>, IndexBased EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index row, Index col) const { - return m_functor(m_arg1Impl.coeff(row, col), m_arg2Impl.coeff(row, col), m_arg3Impl.coeff(row, col)); + return m_d.func()(m_d.arg1Impl.coeff(row, col), m_d.arg2Impl.coeff(row, col), m_d.arg3Impl.coeff(row, col)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { - return m_functor(m_arg1Impl.coeff(index), m_arg2Impl.coeff(index), m_arg3Impl.coeff(index)); + return m_d.func()(m_d.arg1Impl.coeff(index), m_d.arg2Impl.coeff(index), m_d.arg3Impl.coeff(index)); } template<int LoadMode, typename PacketType> EIGEN_STRONG_INLINE PacketType packet(Index row, Index col) const { - return m_functor.packetOp(m_arg1Impl.template packet<LoadMode,PacketType>(row, col), - m_arg2Impl.template packet<LoadMode,PacketType>(row, col), - m_arg3Impl.template packet<LoadMode,PacketType>(row, col)); + return m_d.func().packetOp(m_d.arg1Impl.template packet<LoadMode,PacketType>(row, col), + m_d.arg2Impl.template packet<LoadMode,PacketType>(row, col), + m_d.arg3Impl.template packet<LoadMode,PacketType>(row, col)); } template<int LoadMode, typename PacketType> EIGEN_STRONG_INLINE PacketType packet(Index index) const { - return m_functor.packetOp(m_arg1Impl.template packet<LoadMode,PacketType>(index), - m_arg2Impl.template packet<LoadMode,PacketType>(index), - m_arg3Impl.template packet<LoadMode,PacketType>(index)); + return m_d.func().packetOp(m_d.arg1Impl.template packet<LoadMode,PacketType>(index), + m_d.arg2Impl.template packet<LoadMode,PacketType>(index), + m_d.arg3Impl.template packet<LoadMode,PacketType>(index)); } protected: - const TernaryOp m_functor; - evaluator<Arg1> m_arg1Impl; - evaluator<Arg2> m_arg2Impl; - evaluator<Arg3> m_arg3Impl; + // this helper permits to completely eliminate the functor if it is empty + struct Data : private TernaryOp + { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Data(const XprType& xpr) : TernaryOp(xpr.functor()), arg1Impl(xpr.arg1()), arg2Impl(xpr.arg2()), arg3Impl(xpr.arg3()) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TernaryOp& func() const { return static_cast<const TernaryOp&>(*this); } + evaluator<Arg1> arg1Impl; + evaluator<Arg2> arg2Impl; + evaluator<Arg3> arg3Impl; + }; + + Data m_d; }; // -------------------- CwiseBinaryOp -------------------- @@ -704,10 +738,7 @@ struct binary_evaluator<CwiseBinaryOp<BinaryOp, Lhs, Rhs>, IndexBased, IndexBase Alignment = EIGEN_PLAIN_ENUM_MIN(evaluator<Lhs>::Alignment,evaluator<Rhs>::Alignment) }; - EIGEN_DEVICE_FUNC explicit binary_evaluator(const XprType& xpr) - : m_functor(xpr.functor()), - m_lhsImpl(xpr.lhs()), - m_rhsImpl(xpr.rhs()) + EIGEN_DEVICE_FUNC explicit binary_evaluator(const XprType& xpr) : m_d(xpr) { EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<BinaryOp>::Cost); EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); @@ -718,35 +749,45 @@ struct binary_evaluator<CwiseBinaryOp<BinaryOp, Lhs, Rhs>, IndexBased, IndexBase EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index row, Index col) const { - return m_functor(m_lhsImpl.coeff(row, col), m_rhsImpl.coeff(row, col)); + return m_d.func()(m_d.lhsImpl.coeff(row, col), m_d.rhsImpl.coeff(row, col)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { - return m_functor(m_lhsImpl.coeff(index), m_rhsImpl.coeff(index)); + return m_d.func()(m_d.lhsImpl.coeff(index), m_d.rhsImpl.coeff(index)); } template<int LoadMode, typename PacketType> EIGEN_STRONG_INLINE PacketType packet(Index row, Index col) const { - return m_functor.packetOp(m_lhsImpl.template packet<LoadMode,PacketType>(row, col), - m_rhsImpl.template packet<LoadMode,PacketType>(row, col)); + return m_d.func().packetOp(m_d.lhsImpl.template packet<LoadMode,PacketType>(row, col), + m_d.rhsImpl.template packet<LoadMode,PacketType>(row, col)); } template<int LoadMode, typename PacketType> EIGEN_STRONG_INLINE PacketType packet(Index index) const { - return m_functor.packetOp(m_lhsImpl.template packet<LoadMode,PacketType>(index), - m_rhsImpl.template packet<LoadMode,PacketType>(index)); + return m_d.func().packetOp(m_d.lhsImpl.template packet<LoadMode,PacketType>(index), + m_d.rhsImpl.template packet<LoadMode,PacketType>(index)); } protected: - const BinaryOp m_functor; - evaluator<Lhs> m_lhsImpl; - evaluator<Rhs> m_rhsImpl; + + // this helper permits to completely eliminate the functor if it is empty + struct Data : private BinaryOp + { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Data(const XprType& xpr) : BinaryOp(xpr.functor()), lhsImpl(xpr.lhs()), rhsImpl(xpr.rhs()) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const BinaryOp& func() const { return static_cast<const BinaryOp&>(*this); } + evaluator<Lhs> lhsImpl; + evaluator<Rhs> rhsImpl; + }; + + Data m_d; }; // -------------------- CwiseUnaryView -------------------- @@ -765,9 +806,7 @@ struct unary_evaluator<CwiseUnaryView<UnaryOp, ArgType>, IndexBased> Alignment = 0 // FIXME it is not very clear why alignment is necessarily lost... }; - EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& op) - : m_unaryOp(op.functor()), - m_argImpl(op.nestedExpression()) + EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& op) : m_d(op) { EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<UnaryOp>::Cost); EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); @@ -779,30 +818,40 @@ struct unary_evaluator<CwiseUnaryView<UnaryOp, ArgType>, IndexBased> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index row, Index col) const { - return m_unaryOp(m_argImpl.coeff(row, col)); + return m_d.func()(m_d.argImpl.coeff(row, col)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { - return m_unaryOp(m_argImpl.coeff(index)); + return m_d.func()(m_d.argImpl.coeff(index)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index row, Index col) { - return m_unaryOp(m_argImpl.coeffRef(row, col)); + return m_d.func()(m_d.argImpl.coeffRef(row, col)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { - return m_unaryOp(m_argImpl.coeffRef(index)); + return m_d.func()(m_d.argImpl.coeffRef(index)); } protected: - const UnaryOp m_unaryOp; - evaluator<ArgType> m_argImpl; + + // this helper permits to completely eliminate the functor if it is empty + struct Data : private UnaryOp + { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Data(const XprType& xpr) : UnaryOp(xpr.functor()), argImpl(xpr.nestedExpression()) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const UnaryOp& func() const { return static_cast<const UnaryOp&>(*this); } + evaluator<ArgType> argImpl; + }; + + Data m_d; }; // -------------------- Map -------------------- |