diff options
Diffstat (limited to 'unsupported/Eigen/CXX11')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h | 33 |
1 files changed, 31 insertions, 2 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h index cdbafbbb1..44493906d 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h @@ -51,7 +51,10 @@ struct nested<TensorConversionOp<TargetType, XprType>, 1, typename eval<TensorCo template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket, int SrcCoeffRatio, int TgtCoeffRatio> -struct PacketConverter { +struct PacketConverter; + +template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket> +struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 1, 1> { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketConverter(const TensorEvaluator& impl) : m_impl(impl) {} @@ -109,7 +112,33 @@ struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 4, 1> { }; template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket> -struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 1, 2> { +struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 8, 1> { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + PacketConverter(const TensorEvaluator& impl) + : m_impl(impl) {} + + template<int LoadMode, typename Index> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TgtPacket packet(Index index) const { + const int SrcPacketSize = internal::unpacket_traits<SrcPacket>::size; + + SrcPacket src1 = m_impl.template packet<LoadMode>(index); + SrcPacket src2 = m_impl.template packet<LoadMode>(index + 1 * SrcPacketSize); + SrcPacket src3 = m_impl.template packet<LoadMode>(index + 2 * SrcPacketSize); + SrcPacket src4 = m_impl.template packet<LoadMode>(index + 3 * SrcPacketSize); + SrcPacket src5 = m_impl.template packet<LoadMode>(index + 4 * SrcPacketSize); + SrcPacket src6 = m_impl.template packet<LoadMode>(index + 5 * SrcPacketSize); + SrcPacket src7 = m_impl.template packet<LoadMode>(index + 6 * SrcPacketSize); + SrcPacket src8 = m_impl.template packet<LoadMode>(index + 7 * SrcPacketSize); + TgtPacket result = internal::pcast<SrcPacket, TgtPacket>(src1, src2, src3, src4, src5, src6, src7, src8); + return result; + } + + private: + const TensorEvaluator& m_impl; +}; + +template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket, int TgtCoeffRatio> +struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 1, TgtCoeffRatio> { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketConverter(const TensorEvaluator& impl) : m_impl(impl), m_maxIndex(impl.dimensions().TotalSize()) {} |