From 9cb8771e9c4a1f44ba59741c9fac495d1872bb25 Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Thu, 25 Jun 2020 14:31:16 -0700 Subject: Fix tensor casts for large packets and casts to/from std::complex The original tensor casts were only defined for `SrcCoeffRatio`:`TgtCoeffRatio` 1:1, 1:2, 2:1, 4:1. Here we add the missing 1:N and 8:1. We also add casting `Eigen::half` to/from `std::complex`, which was missing to make it consistent with `Eigen:bfloat16`, and generalize the overload to work for any complex type. Tests were added to `basicstuff`, `packetmath`, and `cxx11_tensor_casts` to test all cast configurations. --- .../Eigen/CXX11/src/Tensor/TensorConversion.h | 33 ++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) (limited to 'unsupported/Eigen/CXX11') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h index cdbafbbb1..44493906d 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h @@ -51,7 +51,10 @@ struct nested, 1, typename eval -struct PacketConverter { +struct PacketConverter; + +template +struct PacketConverter { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketConverter(const TensorEvaluator& impl) : m_impl(impl) {} @@ -109,7 +112,33 @@ struct PacketConverter { }; template -struct PacketConverter { +struct PacketConverter { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + PacketConverter(const TensorEvaluator& impl) + : m_impl(impl) {} + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TgtPacket packet(Index index) const { + const int SrcPacketSize = internal::unpacket_traits::size; + + SrcPacket src1 = m_impl.template packet(index); + SrcPacket src2 = m_impl.template packet(index + 1 * SrcPacketSize); + SrcPacket src3 = m_impl.template packet(index + 2 * SrcPacketSize); + SrcPacket src4 = m_impl.template packet(index + 3 * SrcPacketSize); + SrcPacket src5 = m_impl.template packet(index + 4 * SrcPacketSize); + SrcPacket src6 = m_impl.template packet(index + 5 * SrcPacketSize); + SrcPacket src7 = m_impl.template packet(index + 6 * SrcPacketSize); + SrcPacket src8 = m_impl.template packet(index + 7 * SrcPacketSize); + TgtPacket result = internal::pcast(src1, src2, src3, src4, src5, src6, src7, src8); + return result; + } + + private: + const TensorEvaluator& m_impl; +}; + +template +struct PacketConverter { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketConverter(const TensorEvaluator& impl) : m_impl(impl), m_maxIndex(impl.dimensions().TotalSize()) {} -- cgit v1.2.3