diff options
Diffstat (limited to 'Eigen/src/Core/arch/AVX512/Complex.h')
-rw-r--r-- | Eigen/src/Core/arch/AVX512/Complex.h | 92 |
1 files changed, 68 insertions, 24 deletions
diff --git a/Eigen/src/Core/arch/AVX512/Complex.h b/Eigen/src/Core/arch/AVX512/Complex.h index 569ee01ff..9a89dd01f 100644 --- a/Eigen/src/Core/arch/AVX512/Complex.h +++ b/Eigen/src/Core/arch/AVX512/Complex.h @@ -56,6 +56,8 @@ template<> struct unpacket_traits<Packet8cf> { typedef Packet4cf half; }; +template<> EIGEN_STRONG_INLINE Packet8cf ptrue<Packet8cf>(const Packet8cf& a) { return Packet8cf(ptrue(Packet16f(a.v))); } +template<> EIGEN_STRONG_INLINE Packet8cf pnot<Packet8cf>(const Packet8cf& a) { return Packet8cf(pnot(Packet16f(a.v))); } template<> EIGEN_STRONG_INLINE Packet8cf padd<Packet8cf>(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(_mm512_add_ps(a.v,b.v)); } template<> EIGEN_STRONG_INLINE Packet8cf psub<Packet8cf>(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(_mm512_sub_ps(a.v,b.v)); } template<> EIGEN_STRONG_INLINE Packet8cf pnegate(const Packet8cf& a) @@ -67,7 +69,7 @@ template<> EIGEN_STRONG_INLINE Packet8cf pconj(const Packet8cf& a) const __m512 mask = _mm512_castsi512_ps(_mm512_setr_epi32( 0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000, 0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000)); - return Packet8cf(_mm512_xor_ps(a.v,mask)); + return Packet8cf(pxor(a.v,mask)); } template<> EIGEN_STRONG_INLINE Packet8cf pmul<Packet8cf>(const Packet8cf& a, const Packet8cf& b) @@ -76,10 +78,16 @@ template<> EIGEN_STRONG_INLINE Packet8cf pmul<Packet8cf>(const Packet8cf& a, con return Packet8cf(_mm512_fmaddsub_ps(_mm512_moveldup_ps(a.v), b.v, tmp2)); } -template<> EIGEN_STRONG_INLINE Packet8cf pand <Packet8cf>(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(_mm512_and_ps(a.v,b.v)); } -template<> EIGEN_STRONG_INLINE Packet8cf por <Packet8cf>(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(_mm512_or_ps(a.v,b.v)); } -template<> EIGEN_STRONG_INLINE Packet8cf pxor <Packet8cf>(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(_mm512_xor_ps(a.v,b.v)); } -template<> EIGEN_STRONG_INLINE Packet8cf pandnot<Packet8cf>(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(_mm512_andnot_ps(b.v,a.v)); } +template<> EIGEN_STRONG_INLINE Packet8cf pand <Packet8cf>(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(pand(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet8cf por <Packet8cf>(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(por(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet8cf pxor <Packet8cf>(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(pxor(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet8cf pandnot<Packet8cf>(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(pandnot(a.v,b.v)); } + +template <> +EIGEN_STRONG_INLINE Packet8cf pcmp_eq(const Packet8cf& a, const Packet8cf& b) { + __m512 eq = pcmp_eq<Packet16f>(a.v, b.v); + return Packet8cf(pand(eq, _mm512_permute_ps(eq, 0xB1))); +} template<> EIGEN_STRONG_INLINE Packet8cf pload <Packet8cf>(const std::complex<float>* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet8cf(pload<Packet16f>(&numext::real_ref(*from))); } template<> EIGEN_STRONG_INLINE Packet8cf ploadu<Packet8cf>(const std::complex<float>* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet8cf(ploadu<Packet16f>(&numext::real_ref(*from))); } @@ -125,20 +133,20 @@ template<> EIGEN_STRONG_INLINE Packet8cf preverse(const Packet8cf& a) { template<> EIGEN_STRONG_INLINE std::complex<float> predux<Packet8cf>(const Packet8cf& a) { - return predux(padd(Packet4cf(_mm512_extractf32x8_ps(a.v,0)), - Packet4cf(_mm512_extractf32x8_ps(a.v,1)))); + return predux(padd(Packet4cf(extract256<0>(a.v)), + Packet4cf(extract256<1>(a.v)))); } template<> EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet8cf>(const Packet8cf& a) { - return predux_mul(pmul(Packet4cf(_mm512_extractf32x8_ps(a.v, 0)), - Packet4cf(_mm512_extractf32x8_ps(a.v, 1)))); + return predux_mul(pmul(Packet4cf(extract256<0>(a.v)), + Packet4cf(extract256<1>(a.v)))); } template <> EIGEN_STRONG_INLINE Packet4cf predux_half_dowto4<Packet8cf>(const Packet8cf& a) { - __m256 lane0 = _mm512_extractf32x8_ps(a.v, 0); - __m256 lane1 = _mm512_extractf32x8_ps(a.v, 1); + __m256 lane0 = extract256<0>(a.v); + __m256 lane1 = extract256<1>(a.v); __m256 res = _mm256_add_ps(lane0, lane1); return Packet4cf(res); } @@ -264,10 +272,18 @@ template<> EIGEN_STRONG_INLINE Packet4cd pmul<Packet4cd>(const Packet4cd& a, con return Packet4cd(_mm512_fmaddsub_pd(tmp1, b.v, odd)); } -template<> EIGEN_STRONG_INLINE Packet4cd pand <Packet4cd>(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(_mm512_and_pd(a.v,b.v)); } -template<> EIGEN_STRONG_INLINE Packet4cd por <Packet4cd>(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(_mm512_or_pd(a.v,b.v)); } -template<> EIGEN_STRONG_INLINE Packet4cd pxor <Packet4cd>(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(_mm512_xor_pd(a.v,b.v)); } -template<> EIGEN_STRONG_INLINE Packet4cd pandnot<Packet4cd>(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(_mm512_andnot_pd(b.v,a.v)); } +template<> EIGEN_STRONG_INLINE Packet4cd ptrue<Packet4cd>(const Packet4cd& a) { return Packet4cd(ptrue(Packet8d(a.v))); } +template<> EIGEN_STRONG_INLINE Packet4cd pnot<Packet4cd>(const Packet4cd& a) { return Packet4cd(pnot(Packet8d(a.v))); } +template<> EIGEN_STRONG_INLINE Packet4cd pand <Packet4cd>(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(pand(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet4cd por <Packet4cd>(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(por(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet4cd pxor <Packet4cd>(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(pxor(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet4cd pandnot<Packet4cd>(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(pandnot(a.v,b.v)); } + +template <> +EIGEN_STRONG_INLINE Packet4cd pcmp_eq(const Packet4cd& a, const Packet4cd& b) { + __m512d eq = pcmp_eq<Packet8d>(a.v, b.v); + return Packet4cd(pand(eq, _mm512_permute_pd(eq, 0x55))); +} template<> EIGEN_STRONG_INLINE Packet4cd pload <Packet4cd>(const std::complex<double>* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet4cd(pload<Packet8d>((const double*)from)); } @@ -294,23 +310,23 @@ template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<double> >(std::complex< template<> EIGEN_DEVICE_FUNC inline Packet4cd pgather<std::complex<double>, Packet4cd>(const std::complex<double>* from, Index stride) { return Packet4cd(_mm512_insertf64x4(_mm512_castpd256_pd512( - _mm256_insertf128_pd(_mm256_castpd128_pd256(pload<Packet1cd>(from+0*stride).v), pload<Packet1cd>(from+1*stride).v,1)), - _mm256_insertf128_pd(_mm256_castpd128_pd256(pload<Packet1cd>(from+2*stride).v), pload<Packet1cd>(from+3*stride).v,1), 1)); + _mm256_insertf128_pd(_mm256_castpd128_pd256(ploadu<Packet1cd>(from+0*stride).v), ploadu<Packet1cd>(from+1*stride).v,1)), + _mm256_insertf128_pd(_mm256_castpd128_pd256(ploadu<Packet1cd>(from+2*stride).v), ploadu<Packet1cd>(from+3*stride).v,1), 1)); } template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<double>, Packet4cd>(std::complex<double>* to, const Packet4cd& from, Index stride) { __m512i fromi = _mm512_castpd_si512(from.v); double* tod = (double*)(void*)to; - _mm_store_pd(tod+0*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,0)) ); - _mm_store_pd(tod+2*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,1)) ); - _mm_store_pd(tod+4*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,2)) ); - _mm_store_pd(tod+6*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,3)) ); + _mm_storeu_pd(tod+0*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,0)) ); + _mm_storeu_pd(tod+2*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,1)) ); + _mm_storeu_pd(tod+4*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,2)) ); + _mm_storeu_pd(tod+6*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,3)) ); } template<> EIGEN_STRONG_INLINE std::complex<double> pfirst<Packet4cd>(const Packet4cd& a) { - __m128d low = _mm512_extractf64x2_pd(a.v, 0); + __m128d low = extract128<0>(a.v); EIGEN_ALIGN16 double res[2]; _mm_store_pd(res, low); return std::complex<double>(res[0],res[1]); @@ -392,12 +408,40 @@ template<> EIGEN_STRONG_INLINE Packet4cd pcplxflip<Packet4cd>(const Packet4cd& x EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8cf,4>& kernel) { - ptranspose(reinterpret_cast<PacketBlock<Packet8d,4>&>(kernel)); + PacketBlock<Packet8d,4> pb; + + pb.packet[0] = _mm512_castps_pd(kernel.packet[0].v); + pb.packet[1] = _mm512_castps_pd(kernel.packet[1].v); + pb.packet[2] = _mm512_castps_pd(kernel.packet[2].v); + pb.packet[3] = _mm512_castps_pd(kernel.packet[3].v); + ptranspose(pb); + kernel.packet[0].v = _mm512_castpd_ps(pb.packet[0]); + kernel.packet[1].v = _mm512_castpd_ps(pb.packet[1]); + kernel.packet[2].v = _mm512_castpd_ps(pb.packet[2]); + kernel.packet[3].v = _mm512_castpd_ps(pb.packet[3]); } EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8cf,8>& kernel) { - ptranspose(reinterpret_cast<PacketBlock<Packet8d,8>&>(kernel)); + PacketBlock<Packet8d,8> pb; + + pb.packet[0] = _mm512_castps_pd(kernel.packet[0].v); + pb.packet[1] = _mm512_castps_pd(kernel.packet[1].v); + pb.packet[2] = _mm512_castps_pd(kernel.packet[2].v); + pb.packet[3] = _mm512_castps_pd(kernel.packet[3].v); + pb.packet[4] = _mm512_castps_pd(kernel.packet[4].v); + pb.packet[5] = _mm512_castps_pd(kernel.packet[5].v); + pb.packet[6] = _mm512_castps_pd(kernel.packet[6].v); + pb.packet[7] = _mm512_castps_pd(kernel.packet[7].v); + ptranspose(pb); + kernel.packet[0].v = _mm512_castpd_ps(pb.packet[0]); + kernel.packet[1].v = _mm512_castpd_ps(pb.packet[1]); + kernel.packet[2].v = _mm512_castpd_ps(pb.packet[2]); + kernel.packet[3].v = _mm512_castpd_ps(pb.packet[3]); + kernel.packet[4].v = _mm512_castpd_ps(pb.packet[4]); + kernel.packet[5].v = _mm512_castpd_ps(pb.packet[5]); + kernel.packet[6].v = _mm512_castpd_ps(pb.packet[6]); + kernel.packet[7].v = _mm512_castpd_ps(pb.packet[7]); } EIGEN_DEVICE_FUNC inline void |