diff options
author | Antonio Sanchez <cantonios@google.com> | 2020-11-19 15:44:19 -0800 |
---|---|---|
committer | Antonio Sanchez <cantonios@google.com> | 2020-11-21 09:05:10 -0800 |
commit | 4cf01d2cf5e10c38fdec01acd335b11b924de399 (patch) | |
tree | 91e1d0f8dd66d1ec7fb3dfc2f58bc7e928a27e4f /Eigen/src/Core/arch/AVX | |
parent | fd1dcb6b45a2c797ad4c4d6cc7678ee70763b4ed (diff) |
Update AVX half packets, disable test.
The AVX half implementation is incomplete, causing the `packetmath_13` test
to fail. This disables the test.
Also refactored the existing AVX implementation to use `bit_cast`
instead of direct access to `.x`.
Diffstat (limited to 'Eigen/src/Core/arch/AVX')
-rw-r--r-- | Eigen/src/Core/arch/AVX/PacketMath.h | 81 |
1 files changed, 49 insertions, 32 deletions
diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index ae111c671..b68351356 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -870,14 +870,16 @@ template<> EIGEN_STRONG_INLINE Packet4d pblend(const Selector<4>& ifPacket, cons } // Packet math for Eigen::half +// TODO(cantonios): add missing packet ops +// - pabs, pmin, pmax, plset, pround, print, pceil, pfloor, pcmp_lt, pcmp_le, pcmp_lt_or_nan template<> struct unpacket_traits<Packet8h> { typedef Eigen::half type; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet8h half; }; template<> EIGEN_STRONG_INLINE Packet8h pset1<Packet8h>(const Eigen::half& from) { - return _mm_set1_epi16(from.x); + return _mm_set1_epi16(numext::bit_cast<numext::uint16_t>(from)); } template<> EIGEN_STRONG_INLINE Eigen::half pfirst<Packet8h>(const Packet8h& from) { - return half_impl::raw_uint16_to_half(static_cast<unsigned short>(_mm_extract_epi16(from, 0))); + return numext::bit_cast<Eigen::half>(static_cast<numext::uint16_t>(_mm_extract_epi16(from, 0))); } template<> EIGEN_STRONG_INLINE Packet8h pload<Packet8h>(const Eigen::half* from) { @@ -898,17 +900,17 @@ template<> EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const template<> EIGEN_STRONG_INLINE Packet8h ploaddup<Packet8h>(const Eigen::half* from) { - unsigned short a = from[0].x; - unsigned short b = from[1].x; - unsigned short c = from[2].x; - unsigned short d = from[3].x; + const numext::uint16_t a = numext::bit_cast<numext::uint16_t>(from[0]); + const numext::uint16_t b = numext::bit_cast<numext::uint16_t>(from[1]); + const numext::uint16_t c = numext::bit_cast<numext::uint16_t>(from[2]); + const numext::uint16_t d = numext::bit_cast<numext::uint16_t>(from[3]); return _mm_set_epi16(d, d, c, c, b, b, a, a); } template<> EIGEN_STRONG_INLINE Packet8h ploadquad<Packet8h>(const Eigen::half* from) { - unsigned short a = from[0].x; - unsigned short b = from[1].x; + const numext::uint16_t a = numext::bit_cast<numext::uint16_t>(from[0]); + const numext::uint16_t b = numext::bit_cast<numext::uint16_t>(from[1]); return _mm_set_epi16(b, b, b, b, a, a, a, a); } @@ -937,16 +939,15 @@ EIGEN_STRONG_INLINE Packet8h float2half(const Packet8f& a) { #else EIGEN_ALIGN32 float aux[8]; pstore(aux, a); - Eigen::half h0(aux[0]); - Eigen::half h1(aux[1]); - Eigen::half h2(aux[2]); - Eigen::half h3(aux[3]); - Eigen::half h4(aux[4]); - Eigen::half h5(aux[5]); - Eigen::half h6(aux[6]); - Eigen::half h7(aux[7]); - - return _mm_set_epi16(h7.x, h6.x, h5.x, h4.x, h3.x, h2.x, h1.x, h0.x); + const numext::uint16_t s0 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[0])); + const numext::uint16_t s1 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[1])); + const numext::uint16_t s2 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[2])); + const numext::uint16_t s3 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[3])); + const numext::uint16_t s4 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[4])); + const numext::uint16_t s5 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[5])); + const numext::uint16_t s6 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[6])); + const numext::uint16_t s7 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[7])); + return _mm_set_epi16(s7, s6, s5, s4, s3, s2, s1, s0); #endif } @@ -985,7 +986,7 @@ template<> EIGEN_STRONG_INLINE Packet8h pcmp_eq(const Packet8h& a,const Packet8h template<> EIGEN_STRONG_INLINE Packet8h pconj(const Packet8h& a) { return a; } template<> EIGEN_STRONG_INLINE Packet8h pnegate(const Packet8h& a) { - Packet8h sign_mask = _mm_set1_epi16(static_cast<unsigned short>(0x8000)); + Packet8h sign_mask = _mm_set1_epi16(static_cast<numext::uint16_t>(0x8000)); return _mm_xor_si128(a, sign_mask); } @@ -1019,7 +1020,15 @@ template<> EIGEN_STRONG_INLINE Packet8h pdiv<Packet8h>(const Packet8h& a, const template<> EIGEN_STRONG_INLINE Packet8h pgather<Eigen::half, Packet8h>(const Eigen::half* from, Index stride) { - return _mm_set_epi16(from[7*stride].x, from[6*stride].x, from[5*stride].x, from[4*stride].x, from[3*stride].x, from[2*stride].x, from[1*stride].x, from[0*stride].x); + const numext::uint16_t s0 = numext::bit_cast<numext::uint16_t>(from[0*stride]); + const numext::uint16_t s1 = numext::bit_cast<numext::uint16_t>(from[1*stride]); + const numext::uint16_t s2 = numext::bit_cast<numext::uint16_t>(from[2*stride]); + const numext::uint16_t s3 = numext::bit_cast<numext::uint16_t>(from[3*stride]); + const numext::uint16_t s4 = numext::bit_cast<numext::uint16_t>(from[4*stride]); + const numext::uint16_t s5 = numext::bit_cast<numext::uint16_t>(from[5*stride]); + const numext::uint16_t s6 = numext::bit_cast<numext::uint16_t>(from[6*stride]); + const numext::uint16_t s7 = numext::bit_cast<numext::uint16_t>(from[7*stride]); + return _mm_set_epi16(s7, s6, s5, s4, s3, s2, s1, s0); } template<> EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet8h>(Eigen::half* to, const Packet8h& from, Index stride) @@ -1178,7 +1187,7 @@ EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) { __m256 mask = _mm256_cmp_ps(flush, flush, _CMP_ORD_Q); __m256i nan = _mm256_set1_epi32(0x7fc0); t = _mm256_blendv_epi8(nan, t, _mm256_castps_si256(mask)); - // output.value = static_cast<uint16_t>(input); + // output = numext::bit_cast<uint16_t>(input); return _mm_packus_epi32(_mm256_extractf128_si256(t, 0), _mm256_extractf128_si256(t, 1)); #else @@ -1202,17 +1211,17 @@ EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) { __m128i nan = _mm_set1_epi32(0x7fc0); lo = _mm_blendv_epi8(nan, lo, _mm_castps_si128(_mm256_castps256_ps128(mask))); hi = _mm_blendv_epi8(nan, hi, _mm_castps_si128(_mm256_extractf128_ps(mask, 1))); - // output.value = static_cast<uint16_t>(input); + // output = numext::bit_cast<uint16_t>(input); return _mm_packus_epi32(lo, hi); #endif } template<> EIGEN_STRONG_INLINE Packet8bf pset1<Packet8bf>(const bfloat16& from) { - return _mm_set1_epi16(from.value); + return _mm_set1_epi16(numext::bit_cast<numext::uint16_t>(from)); } template<> EIGEN_STRONG_INLINE bfloat16 pfirst<Packet8bf>(const Packet8bf& from) { - return bfloat16_impl::raw_uint16_to_bfloat16(static_cast<unsigned short>(_mm_extract_epi16(from, 0))); + return numext::bit_cast<bfloat16>(static_cast<numext::uint16_t>(_mm_extract_epi16(from, 0))); } template<> EIGEN_STRONG_INLINE Packet8bf pload<Packet8bf>(const bfloat16* from) { @@ -1233,17 +1242,17 @@ template<> EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet template<> EIGEN_STRONG_INLINE Packet8bf ploaddup<Packet8bf>(const bfloat16* from) { - unsigned short a = from[0].value; - unsigned short b = from[1].value; - unsigned short c = from[2].value; - unsigned short d = from[3].value; + const numext::uint16_t a = numext::bit_cast<numext::uint16_t>(from[0]); + const numext::uint16_t b = numext::bit_cast<numext::uint16_t>(from[1]); + const numext::uint16_t c = numext::bit_cast<numext::uint16_t>(from[2]); + const numext::uint16_t d = numext::bit_cast<numext::uint16_t>(from[3]); return _mm_set_epi16(d, d, c, c, b, b, a, a); } template<> EIGEN_STRONG_INLINE Packet8bf ploadquad<Packet8bf>(const bfloat16* from) { - unsigned short a = from[0].value; - unsigned short b = from[1].value; + const numext::uint16_t a = numext::bit_cast<numext::uint16_t>(from[0]); + const numext::uint16_t b = numext::bit_cast<numext::uint16_t>(from[1]); return _mm_set_epi16(b, b, b, b, a, a, a, a); } @@ -1326,7 +1335,7 @@ template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt_or_nan(const Packet8bf& a,const template<> EIGEN_STRONG_INLINE Packet8bf pconj(const Packet8bf& a) { return a; } template<> EIGEN_STRONG_INLINE Packet8bf pnegate(const Packet8bf& a) { - Packet8bf sign_mask = _mm_set1_epi16(static_cast<unsigned short>(0x8000)); + Packet8bf sign_mask = _mm_set1_epi16(static_cast<numext::uint16_t>(0x8000)); return _mm_xor_si128(a, sign_mask); } @@ -1349,7 +1358,15 @@ template<> EIGEN_STRONG_INLINE Packet8bf pdiv<Packet8bf>(const Packet8bf& a, con template<> EIGEN_STRONG_INLINE Packet8bf pgather<bfloat16, Packet8bf>(const bfloat16* from, Index stride) { - return _mm_set_epi16(from[7*stride].value, from[6*stride].value, from[5*stride].value, from[4*stride].value, from[3*stride].value, from[2*stride].value, from[1*stride].value, from[0*stride].value); + const numext::uint16_t s0 = numext::bit_cast<numext::uint16_t>(from[0*stride]); + const numext::uint16_t s1 = numext::bit_cast<numext::uint16_t>(from[1*stride]); + const numext::uint16_t s2 = numext::bit_cast<numext::uint16_t>(from[2*stride]); + const numext::uint16_t s3 = numext::bit_cast<numext::uint16_t>(from[3*stride]); + const numext::uint16_t s4 = numext::bit_cast<numext::uint16_t>(from[4*stride]); + const numext::uint16_t s5 = numext::bit_cast<numext::uint16_t>(from[5*stride]); + const numext::uint16_t s6 = numext::bit_cast<numext::uint16_t>(from[6*stride]); + const numext::uint16_t s7 = numext::bit_cast<numext::uint16_t>(from[7*stride]); + return _mm_set_epi16(s7, s6, s5, s4, s3, s2, s1, s0); } template<> EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet8bf>(bfloat16* to, const Packet8bf& from, Index stride) |