diff options
-rw-r--r-- | Eigen/src/Core/arch/AVX/PacketMath.h | 81 | ||||
-rw-r--r-- | test/packetmath.cpp | 10 |
2 files changed, 57 insertions, 34 deletions
diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index ae111c671..b68351356 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -870,14 +870,16 @@ template<> EIGEN_STRONG_INLINE Packet4d pblend(const Selector<4>& ifPacket, cons } // Packet math for Eigen::half +// TODO(cantonios): add missing packet ops +// - pabs, pmin, pmax, plset, pround, print, pceil, pfloor, pcmp_lt, pcmp_le, pcmp_lt_or_nan template<> struct unpacket_traits<Packet8h> { typedef Eigen::half type; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet8h half; }; template<> EIGEN_STRONG_INLINE Packet8h pset1<Packet8h>(const Eigen::half& from) { - return _mm_set1_epi16(from.x); + return _mm_set1_epi16(numext::bit_cast<numext::uint16_t>(from)); } template<> EIGEN_STRONG_INLINE Eigen::half pfirst<Packet8h>(const Packet8h& from) { - return half_impl::raw_uint16_to_half(static_cast<unsigned short>(_mm_extract_epi16(from, 0))); + return numext::bit_cast<Eigen::half>(static_cast<numext::uint16_t>(_mm_extract_epi16(from, 0))); } template<> EIGEN_STRONG_INLINE Packet8h pload<Packet8h>(const Eigen::half* from) { @@ -898,17 +900,17 @@ template<> EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const template<> EIGEN_STRONG_INLINE Packet8h ploaddup<Packet8h>(const Eigen::half* from) { - unsigned short a = from[0].x; - unsigned short b = from[1].x; - unsigned short c = from[2].x; - unsigned short d = from[3].x; + const numext::uint16_t a = numext::bit_cast<numext::uint16_t>(from[0]); + const numext::uint16_t b = numext::bit_cast<numext::uint16_t>(from[1]); + const numext::uint16_t c = numext::bit_cast<numext::uint16_t>(from[2]); + const numext::uint16_t d = numext::bit_cast<numext::uint16_t>(from[3]); return _mm_set_epi16(d, d, c, c, b, b, a, a); } template<> EIGEN_STRONG_INLINE Packet8h ploadquad<Packet8h>(const Eigen::half* from) { - unsigned short a = from[0].x; - unsigned short b = from[1].x; + const numext::uint16_t a = numext::bit_cast<numext::uint16_t>(from[0]); + const numext::uint16_t b = numext::bit_cast<numext::uint16_t>(from[1]); return _mm_set_epi16(b, b, b, b, a, a, a, a); } @@ -937,16 +939,15 @@ EIGEN_STRONG_INLINE Packet8h float2half(const Packet8f& a) { #else EIGEN_ALIGN32 float aux[8]; pstore(aux, a); - Eigen::half h0(aux[0]); - Eigen::half h1(aux[1]); - Eigen::half h2(aux[2]); - Eigen::half h3(aux[3]); - Eigen::half h4(aux[4]); - Eigen::half h5(aux[5]); - Eigen::half h6(aux[6]); - Eigen::half h7(aux[7]); - - return _mm_set_epi16(h7.x, h6.x, h5.x, h4.x, h3.x, h2.x, h1.x, h0.x); + const numext::uint16_t s0 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[0])); + const numext::uint16_t s1 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[1])); + const numext::uint16_t s2 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[2])); + const numext::uint16_t s3 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[3])); + const numext::uint16_t s4 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[4])); + const numext::uint16_t s5 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[5])); + const numext::uint16_t s6 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[6])); + const numext::uint16_t s7 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[7])); + return _mm_set_epi16(s7, s6, s5, s4, s3, s2, s1, s0); #endif } @@ -985,7 +986,7 @@ template<> EIGEN_STRONG_INLINE Packet8h pcmp_eq(const Packet8h& a,const Packet8h template<> EIGEN_STRONG_INLINE Packet8h pconj(const Packet8h& a) { return a; } template<> EIGEN_STRONG_INLINE Packet8h pnegate(const Packet8h& a) { - Packet8h sign_mask = _mm_set1_epi16(static_cast<unsigned short>(0x8000)); + Packet8h sign_mask = _mm_set1_epi16(static_cast<numext::uint16_t>(0x8000)); return _mm_xor_si128(a, sign_mask); } @@ -1019,7 +1020,15 @@ template<> EIGEN_STRONG_INLINE Packet8h pdiv<Packet8h>(const Packet8h& a, const template<> EIGEN_STRONG_INLINE Packet8h pgather<Eigen::half, Packet8h>(const Eigen::half* from, Index stride) { - return _mm_set_epi16(from[7*stride].x, from[6*stride].x, from[5*stride].x, from[4*stride].x, from[3*stride].x, from[2*stride].x, from[1*stride].x, from[0*stride].x); + const numext::uint16_t s0 = numext::bit_cast<numext::uint16_t>(from[0*stride]); + const numext::uint16_t s1 = numext::bit_cast<numext::uint16_t>(from[1*stride]); + const numext::uint16_t s2 = numext::bit_cast<numext::uint16_t>(from[2*stride]); + const numext::uint16_t s3 = numext::bit_cast<numext::uint16_t>(from[3*stride]); + const numext::uint16_t s4 = numext::bit_cast<numext::uint16_t>(from[4*stride]); + const numext::uint16_t s5 = numext::bit_cast<numext::uint16_t>(from[5*stride]); + const numext::uint16_t s6 = numext::bit_cast<numext::uint16_t>(from[6*stride]); + const numext::uint16_t s7 = numext::bit_cast<numext::uint16_t>(from[7*stride]); + return _mm_set_epi16(s7, s6, s5, s4, s3, s2, s1, s0); } template<> EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet8h>(Eigen::half* to, const Packet8h& from, Index stride) @@ -1178,7 +1187,7 @@ EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) { __m256 mask = _mm256_cmp_ps(flush, flush, _CMP_ORD_Q); __m256i nan = _mm256_set1_epi32(0x7fc0); t = _mm256_blendv_epi8(nan, t, _mm256_castps_si256(mask)); - // output.value = static_cast<uint16_t>(input); + // output = numext::bit_cast<uint16_t>(input); return _mm_packus_epi32(_mm256_extractf128_si256(t, 0), _mm256_extractf128_si256(t, 1)); #else @@ -1202,17 +1211,17 @@ EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) { __m128i nan = _mm_set1_epi32(0x7fc0); lo = _mm_blendv_epi8(nan, lo, _mm_castps_si128(_mm256_castps256_ps128(mask))); hi = _mm_blendv_epi8(nan, hi, _mm_castps_si128(_mm256_extractf128_ps(mask, 1))); - // output.value = static_cast<uint16_t>(input); + // output = numext::bit_cast<uint16_t>(input); return _mm_packus_epi32(lo, hi); #endif } template<> EIGEN_STRONG_INLINE Packet8bf pset1<Packet8bf>(const bfloat16& from) { - return _mm_set1_epi16(from.value); + return _mm_set1_epi16(numext::bit_cast<numext::uint16_t>(from)); } template<> EIGEN_STRONG_INLINE bfloat16 pfirst<Packet8bf>(const Packet8bf& from) { - return bfloat16_impl::raw_uint16_to_bfloat16(static_cast<unsigned short>(_mm_extract_epi16(from, 0))); + return numext::bit_cast<bfloat16>(static_cast<numext::uint16_t>(_mm_extract_epi16(from, 0))); } template<> EIGEN_STRONG_INLINE Packet8bf pload<Packet8bf>(const bfloat16* from) { @@ -1233,17 +1242,17 @@ template<> EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet template<> EIGEN_STRONG_INLINE Packet8bf ploaddup<Packet8bf>(const bfloat16* from) { - unsigned short a = from[0].value; - unsigned short b = from[1].value; - unsigned short c = from[2].value; - unsigned short d = from[3].value; + const numext::uint16_t a = numext::bit_cast<numext::uint16_t>(from[0]); + const numext::uint16_t b = numext::bit_cast<numext::uint16_t>(from[1]); + const numext::uint16_t c = numext::bit_cast<numext::uint16_t>(from[2]); + const numext::uint16_t d = numext::bit_cast<numext::uint16_t>(from[3]); return _mm_set_epi16(d, d, c, c, b, b, a, a); } template<> EIGEN_STRONG_INLINE Packet8bf ploadquad<Packet8bf>(const bfloat16* from) { - unsigned short a = from[0].value; - unsigned short b = from[1].value; + const numext::uint16_t a = numext::bit_cast<numext::uint16_t>(from[0]); + const numext::uint16_t b = numext::bit_cast<numext::uint16_t>(from[1]); return _mm_set_epi16(b, b, b, b, a, a, a, a); } @@ -1326,7 +1335,7 @@ template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt_or_nan(const Packet8bf& a,const template<> EIGEN_STRONG_INLINE Packet8bf pconj(const Packet8bf& a) { return a; } template<> EIGEN_STRONG_INLINE Packet8bf pnegate(const Packet8bf& a) { - Packet8bf sign_mask = _mm_set1_epi16(static_cast<unsigned short>(0x8000)); + Packet8bf sign_mask = _mm_set1_epi16(static_cast<numext::uint16_t>(0x8000)); return _mm_xor_si128(a, sign_mask); } @@ -1349,7 +1358,15 @@ template<> EIGEN_STRONG_INLINE Packet8bf pdiv<Packet8bf>(const Packet8bf& a, con template<> EIGEN_STRONG_INLINE Packet8bf pgather<bfloat16, Packet8bf>(const bfloat16* from, Index stride) { - return _mm_set_epi16(from[7*stride].value, from[6*stride].value, from[5*stride].value, from[4*stride].value, from[3*stride].value, from[2*stride].value, from[1*stride].value, from[0*stride].value); + const numext::uint16_t s0 = numext::bit_cast<numext::uint16_t>(from[0*stride]); + const numext::uint16_t s1 = numext::bit_cast<numext::uint16_t>(from[1*stride]); + const numext::uint16_t s2 = numext::bit_cast<numext::uint16_t>(from[2*stride]); + const numext::uint16_t s3 = numext::bit_cast<numext::uint16_t>(from[3*stride]); + const numext::uint16_t s4 = numext::bit_cast<numext::uint16_t>(from[4*stride]); + const numext::uint16_t s5 = numext::bit_cast<numext::uint16_t>(from[5*stride]); + const numext::uint16_t s6 = numext::bit_cast<numext::uint16_t>(from[6*stride]); + const numext::uint16_t s7 = numext::bit_cast<numext::uint16_t>(from[7*stride]); + return _mm_set_epi16(s7, s6, s5, s4, s3, s2, s1, s0); } template<> EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet8bf>(bfloat16* to, const Packet8bf& from, Index stride) diff --git a/test/packetmath.cpp b/test/packetmath.cpp index feef148ad..ae21dda5c 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -1001,8 +1001,9 @@ void packetmath_scatter_gather() { int stride = internal::random<int>(1, 20); - EIGEN_ALIGN_MAX Scalar buffer[PacketSize * 20]; - memset(buffer, 0, 20 * PacketSize * sizeof(Scalar)); + // Buffer of zeros. + EIGEN_ALIGN_MAX Scalar buffer[PacketSize * 20] = {}; + Packet packet = internal::pload<Packet>(data1); internal::pscatter<Scalar, Packet>(buffer, packet, stride); @@ -1073,7 +1074,12 @@ EIGEN_DECLARE_TEST(packetmath) { CALL_SUBTEST_10(test::runner<uint64_t>::run()); CALL_SUBTEST_11(test::runner<std::complex<float> >::run()); CALL_SUBTEST_12(test::runner<std::complex<double> >::run()); +#if defined(EIGEN_VECTORIZE_AVX) + // AVX half packets not fully implemented. + CALL_SUBTEST_13((packetmath<half, internal::packet_traits<half>::type>())); +#else CALL_SUBTEST_13(test::runner<half>::run()); +#endif CALL_SUBTEST_14((packetmath<bool, internal::packet_traits<bool>::type>())); CALL_SUBTEST_15(test::runner<bfloat16>::run()); g_first_pass = false; |