From 3c9add6598cc35e5317788627dfa81f517e89e07 Mon Sep 17 00:00:00 2001 From: Mark D Ryan Date: Fri, 11 Jan 2019 14:02:09 +0100 Subject: Remove reinterpret_cast from AVX512 complex implementation The reinterpret_casts used in ptranspose(PacketBlock&) ptranspose(PacketBlock&) don't appear to be working correctly. They're used to convert the kernel parameters to PacketBlock& so that the complex number versions of ptranspose can be written using the existing double implementations. Unfortunately, they don't seem to work and are responsible for 9 unit test failures in the AVX512 build of tensorflow master. This commit fixes the issue by manually initialising PacketBlock variables with the contents of the kernel parameter before calling the double version of ptranspose, and then copying the resulting values back into the kernel parameter before returning. --- Eigen/src/Core/arch/AVX512/Complex.h | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/Eigen/src/Core/arch/AVX512/Complex.h b/Eigen/src/Core/arch/AVX512/Complex.h index 42cdfcd25..9c4ee1235 100644 --- a/Eigen/src/Core/arch/AVX512/Complex.h +++ b/Eigen/src/Core/arch/AVX512/Complex.h @@ -390,12 +390,40 @@ template<> EIGEN_STRONG_INLINE Packet4cd pcplxflip(const Packet4cd& x EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { - ptranspose(reinterpret_cast&>(kernel)); + PacketBlock pb; + + pb.packet[0] = _mm512_castps_pd(kernel.packet[0].v); + pb.packet[1] = _mm512_castps_pd(kernel.packet[1].v); + pb.packet[2] = _mm512_castps_pd(kernel.packet[2].v); + pb.packet[3] = _mm512_castps_pd(kernel.packet[3].v); + ptranspose(pb); + kernel.packet[0].v = _mm512_castpd_ps(pb.packet[0]); + kernel.packet[1].v = _mm512_castpd_ps(pb.packet[1]); + kernel.packet[2].v = _mm512_castpd_ps(pb.packet[2]); + kernel.packet[3].v = _mm512_castpd_ps(pb.packet[3]); } EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { - ptranspose(reinterpret_cast&>(kernel)); + PacketBlock pb; + + pb.packet[0] = _mm512_castps_pd(kernel.packet[0].v); + pb.packet[1] = _mm512_castps_pd(kernel.packet[1].v); + pb.packet[2] = _mm512_castps_pd(kernel.packet[2].v); + pb.packet[3] = _mm512_castps_pd(kernel.packet[3].v); + pb.packet[4] = _mm512_castps_pd(kernel.packet[4].v); + pb.packet[5] = _mm512_castps_pd(kernel.packet[5].v); + pb.packet[6] = _mm512_castps_pd(kernel.packet[6].v); + pb.packet[7] = _mm512_castps_pd(kernel.packet[7].v); + ptranspose(pb); + kernel.packet[0].v = _mm512_castpd_ps(pb.packet[0]); + kernel.packet[1].v = _mm512_castpd_ps(pb.packet[1]); + kernel.packet[2].v = _mm512_castpd_ps(pb.packet[2]); + kernel.packet[3].v = _mm512_castpd_ps(pb.packet[3]); + kernel.packet[4].v = _mm512_castpd_ps(pb.packet[4]); + kernel.packet[5].v = _mm512_castpd_ps(pb.packet[5]); + kernel.packet[6].v = _mm512_castpd_ps(pb.packet[6]); + kernel.packet[7].v = _mm512_castpd_ps(pb.packet[7]); } EIGEN_DEVICE_FUNC inline void -- cgit v1.2.3