diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h | 81 |
1 files changed, 81 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h index 31b361c83..4e873011e 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h @@ -403,6 +403,87 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg TensorEvaluator<RightArgType, Device> m_rightImpl; }; +// -------------------- CwiseTernaryOp -------------------- + +template<typename TernaryOp, typename Arg1Type, typename Arg2Type, typename Arg3Type, typename Device> +struct TensorEvaluator<const TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type>, Device> +{ + typedef TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type> XprType; + + enum { + IsAligned = TensorEvaluator<Arg1Type, Device>::IsAligned & TensorEvaluator<Arg2Type, Device>::IsAligned & TensorEvaluator<Arg3Type, Device>::IsAligned, + PacketAccess = TensorEvaluator<Arg1Type, Device>::PacketAccess & TensorEvaluator<Arg2Type, Device>::PacketAccess & TensorEvaluator<Arg3Type, Device>::PacketAccess & + internal::functor_traits<TernaryOp>::PacketAccess, + Layout = TensorEvaluator<Arg1Type, Device>::Layout, + CoordAccess = false, // to be implemented + RawAccess = false + }; + + EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) + : m_functor(op.functor()), + m_arg1Impl(op.arg1Expression(), device), + m_arg2Impl(op.arg2Expression(), device), + m_arg3Impl(op.arg3Expression(), device) + { + EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<Arg1Type, Device>::Layout) == static_cast<int>(TensorEvaluator<Arg3Type, Device>::Layout) || internal::traits<XprType>::NumDimensions <= 1), YOU_MADE_A_PROGRAMMING_MISTAKE); + eigen_assert(dimensions_match(m_arg1Impl.dimensions(), m_arg2Impl.dimensions()) && dimensions_match(m_arg1Impl.dimensions(), m_arg3Impl.dimensions())); + } + + typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; + typedef typename internal::traits<XprType>::Scalar CoeffReturnType; + typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; + static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size; + typedef typename TensorEvaluator<Arg1Type, Device>::Dimensions Dimensions; + + EIGEN_DEVICE_FUNC const Dimensions& dimensions() const + { + // TODO: use arg2 or arg3 dimensions if they are known at compile time. + return m_arg1Impl.dimensions(); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) { + m_arg1Impl.evalSubExprsIfNeeded(NULL); + m_arg2Impl.evalSubExprsIfNeeded(NULL); + m_arg3Impl.evalSubExprsIfNeeded(NULL); + return true; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { + m_arg1Impl.cleanup(); + m_arg2Impl.cleanup(); + m_arg3Impl.cleanup(); + } + + EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const + { + return m_functor(m_arg1Impl.coeff(index), m_arg2Impl.coeff(index), m_arg3Impl.coeff(index)); + } + template<int LoadMode> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const + { + return m_functor.packetOp(m_arg1Impl.template packet<LoadMode>(index), + m_arg2Impl.template packet<LoadMode>(index), + m_arg3Impl.template packet<LoadMode>(index)); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost + costPerCoeff(bool vectorized) const { + const double functor_cost = internal::functor_traits<TernaryOp>::Cost; + return m_arg1Impl.costPerCoeff(vectorized) + + m_arg2Impl.costPerCoeff(vectorized) + + m_arg3Impl.costPerCoeff(vectorized) + + TensorOpCost(0, 0, functor_cost, vectorized, PacketSize); + } + + EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; } + + private: + const TernaryOp m_functor; + TensorEvaluator<Arg1Type, Device> m_arg1Impl; + TensorEvaluator<Arg1Type, Device> m_arg2Impl; + TensorEvaluator<Arg3Type, Device> m_arg3Impl; +}; + // -------------------- SelectOp -------------------- |