aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen
diff options
context:
space:
mode:
authorGravatar Mark D Ryan <mark.d.ryan@intel.com>2019-01-11 14:02:09 +0100
committerGravatar Mark D Ryan <mark.d.ryan@intel.com>2019-01-11 14:02:09 +0100
commit3c9add6598cc35e5317788627dfa81f517e89e07 (patch)
tree71b030f5c999821a6d73c693cc3bcbd29c94a040 /Eigen
parent0522460a0d01d4253183349a49144b5ad8ba2f9f (diff)
Remove reinterpret_cast from AVX512 complex implementation
The reinterpret_casts used in ptranspose(PacketBlock<Packet8cf,4>&) ptranspose(PacketBlock<Packet8cf,8>&) don't appear to be working correctly. They're used to convert the kernel parameters to PacketBlock<Packet8d,T>& so that the complex number versions of ptranspose can be written using the existing double implementations. Unfortunately, they don't seem to work and are responsible for 9 unit test failures in the AVX512 build of tensorflow master. This commit fixes the issue by manually initialising PacketBlock<Packet8d,T> variables with the contents of the kernel parameter before calling the double version of ptranspose, and then copying the resulting values back into the kernel parameter before returning.
Diffstat (limited to 'Eigen')
-rw-r--r--Eigen/src/Core/arch/AVX512/Complex.h32
1 files changed, 30 insertions, 2 deletions
diff --git a/Eigen/src/Core/arch/AVX512/Complex.h b/Eigen/src/Core/arch/AVX512/Complex.h
index 42cdfcd25..9c4ee1235 100644
--- a/Eigen/src/Core/arch/AVX512/Complex.h
+++ b/Eigen/src/Core/arch/AVX512/Complex.h
@@ -390,12 +390,40 @@ template<> EIGEN_STRONG_INLINE Packet4cd pcplxflip<Packet4cd>(const Packet4cd& x
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet8cf,4>& kernel) {
- ptranspose(reinterpret_cast<PacketBlock<Packet8d,4>&>(kernel));
+ PacketBlock<Packet8d,4> pb;
+
+ pb.packet[0] = _mm512_castps_pd(kernel.packet[0].v);
+ pb.packet[1] = _mm512_castps_pd(kernel.packet[1].v);
+ pb.packet[2] = _mm512_castps_pd(kernel.packet[2].v);
+ pb.packet[3] = _mm512_castps_pd(kernel.packet[3].v);
+ ptranspose(pb);
+ kernel.packet[0].v = _mm512_castpd_ps(pb.packet[0]);
+ kernel.packet[1].v = _mm512_castpd_ps(pb.packet[1]);
+ kernel.packet[2].v = _mm512_castpd_ps(pb.packet[2]);
+ kernel.packet[3].v = _mm512_castpd_ps(pb.packet[3]);
}
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet8cf,8>& kernel) {
- ptranspose(reinterpret_cast<PacketBlock<Packet8d,8>&>(kernel));
+ PacketBlock<Packet8d,8> pb;
+
+ pb.packet[0] = _mm512_castps_pd(kernel.packet[0].v);
+ pb.packet[1] = _mm512_castps_pd(kernel.packet[1].v);
+ pb.packet[2] = _mm512_castps_pd(kernel.packet[2].v);
+ pb.packet[3] = _mm512_castps_pd(kernel.packet[3].v);
+ pb.packet[4] = _mm512_castps_pd(kernel.packet[4].v);
+ pb.packet[5] = _mm512_castps_pd(kernel.packet[5].v);
+ pb.packet[6] = _mm512_castps_pd(kernel.packet[6].v);
+ pb.packet[7] = _mm512_castps_pd(kernel.packet[7].v);
+ ptranspose(pb);
+ kernel.packet[0].v = _mm512_castpd_ps(pb.packet[0]);
+ kernel.packet[1].v = _mm512_castpd_ps(pb.packet[1]);
+ kernel.packet[2].v = _mm512_castpd_ps(pb.packet[2]);
+ kernel.packet[3].v = _mm512_castpd_ps(pb.packet[3]);
+ kernel.packet[4].v = _mm512_castpd_ps(pb.packet[4]);
+ kernel.packet[5].v = _mm512_castpd_ps(pb.packet[5]);
+ kernel.packet[6].v = _mm512_castpd_ps(pb.packet[6]);
+ kernel.packet[7].v = _mm512_castpd_ps(pb.packet[7]);
}
EIGEN_DEVICE_FUNC inline void