diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h | 51 |
1 files changed, 42 insertions, 9 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h index 3ca7daf32..a96776a77 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h @@ -25,7 +25,6 @@ struct traits<TensorConversionOp<TargetType, XprType> > { // Type promotion to handle the case where the types of the lhs and the rhs are different. typedef TargetType Scalar; - typedef typename packet_traits<Scalar>::type Packet; typedef typename traits<XprType>::StorageKind StorageKind; typedef typename traits<XprType>::Index Index; typedef typename XprType::Nested Nested; @@ -86,6 +85,27 @@ struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 2, 1> { const TensorEvaluator& m_impl; }; +template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket> +struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 4, 1> { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + PacketConverter(const TensorEvaluator& impl) + : m_impl(impl) {} + + template<int LoadMode, typename Index> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TgtPacket packet(Index index) const { + const int SrcPacketSize = internal::unpacket_traits<SrcPacket>::size; + + SrcPacket src1 = m_impl.template packet<LoadMode>(index); + SrcPacket src2 = m_impl.template packet<LoadMode>(index + SrcPacketSize); + SrcPacket src3 = m_impl.template packet<LoadMode>(index + 2 * SrcPacketSize); + SrcPacket src4 = m_impl.template packet<LoadMode>(index + 3 * SrcPacketSize); + TgtPacket result = internal::pcast<SrcPacket, TgtPacket>(src1, src2, src3, src4); + return result; + } + + private: + const TensorEvaluator& m_impl; +}; template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket> struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 1, 2> { @@ -104,9 +124,12 @@ struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 1, 2> { return internal::pcast<SrcPacket, TgtPacket>(m_impl.template packet<Unaligned>(index)); } else { const int TgtPacketSize = internal::unpacket_traits<TgtPacket>::size; + typedef typename internal::unpacket_traits<SrcPacket>::type SrcType; + typedef typename internal::unpacket_traits<TgtPacket>::type TgtType; + internal::scalar_cast_op<SrcType, TgtType> converter; EIGEN_ALIGN_MAX typename internal::unpacket_traits<TgtPacket>::type values[TgtPacketSize]; for (int i = 0; i < TgtPacketSize; ++i) { - values[i] = m_impl.coeff(index+i); + values[i] = converter(m_impl.coeff(index+i)); } TgtPacket rslt = internal::pload<TgtPacket>(values); return rslt; @@ -123,12 +146,10 @@ class TensorConversionOp : public TensorBase<TensorConversionOp<TargetType, XprT { public: typedef typename internal::traits<TensorConversionOp>::Scalar Scalar; - typedef typename internal::traits<TensorConversionOp>::Packet Packet; typedef typename internal::traits<TensorConversionOp>::StorageKind StorageKind; typedef typename internal::traits<TensorConversionOp>::Index Index; typedef typename internal::nested<TensorConversionOp>::type Nested; typedef Scalar CoeffReturnType; - typedef Packet PacketReturnType; typedef typename NumTraits<Scalar>::Real RealScalar; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorConversionOp(const XprType& xpr) @@ -142,6 +163,18 @@ class TensorConversionOp : public TensorBase<TensorConversionOp<TargetType, XprT typename XprType::Nested m_xpr; }; +template <bool SameType, typename Eval, typename Scalar> struct ConversionSubExprEval { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static bool run(Eval& impl, Scalar*) { + impl.evalSubExprsIfNeeded(NULL); + return true; + } +}; + +template <typename Eval, typename Scalar> struct ConversionSubExprEval<true, Eval, Scalar> { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static bool run(Eval& impl, Scalar* data) { + return impl.evalSubExprsIfNeeded(data); + } +}; @@ -155,13 +188,14 @@ struct TensorEvaluator<const TensorConversionOp<TargetType, ArgType>, Device> typedef TargetType Scalar; typedef TargetType CoeffReturnType; typedef typename internal::remove_all<typename internal::traits<ArgType>::Scalar>::type SrcType; - typedef typename internal::traits<XprType>::Packet PacketReturnType; - typedef typename internal::packet_traits<SrcType>::type PacketSourceType; + typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; + typedef typename PacketType<SrcType, Device>::type PacketSourceType; enum { IsAligned = false, PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess && internal::type_casting_traits<SrcType, TargetType>::VectorizedCast, Layout = TensorEvaluator<ArgType, Device>::Layout, + RawAccess = false }; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) @@ -171,10 +205,9 @@ struct TensorEvaluator<const TensorConversionOp<TargetType, ArgType>, Device> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_impl.dimensions(); } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) { - m_impl.evalSubExprsIfNeeded(NULL); - return true; + return ConversionSubExprEval<internal::is_same<TargetType, SrcType>::value, TensorEvaluator<ArgType, Device>, Scalar>::run(m_impl, data); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() |