diff options
author | Antonio Sanchez <cantonios@google.com> | 2021-01-26 10:23:23 -0800 |
---|---|---|
committer | Antonio Sánchez <cantonios@google.com> | 2021-01-28 18:37:09 +0000 |
commit | 1615a2799384a2964d01ba77fe98e3f6fcc412f4 (patch) | |
tree | 1901a8fc4ac9a668138979d015d3d9265eb63163 /Eigen/src/Core/arch/AltiVec | |
parent | 1414e2212c3cd36e2653bca0e11c653ece8f4d04 (diff) |
Fix altivec packetmath.
Allows the altivec packetmath tests to pass. There were a few issues:
- `pstoreu` was missing MSQ on `_BIG_ENDIAN` systems
- `cmp_*` didn't properly handle conversion of bool flags (0x7FC instead
of 0xFFFF)
- `pfrexp` needed to set the `exponent` argument.
Related to !370, #2128
cc: @ChipKerchner @pdrocaldeira
Tested on `_BIG_ENDIAN` running on QEMU with VSX. Couldn't figure out build
flags to get it to work for little endian.
Diffstat (limited to 'Eigen/src/Core/arch/AltiVec')
-rwxr-xr-x | Eigen/src/Core/arch/AltiVec/PacketMath.h | 54 |
1 files changed, 41 insertions, 13 deletions
diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h index 6d7842021..fdf4f1e9c 100755 --- a/Eigen/src/Core/arch/AltiVec/PacketMath.h +++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h @@ -1056,6 +1056,7 @@ template<typename Packet> EIGEN_STRONG_INLINE void pstoreu_common(__UNPACK_TYPE_ MSQ = vec_perm(edges,(Packet16uc)from,align); // misalign the data (MSQ) LSQ = vec_perm((Packet16uc)from,edges,align); // misalign the data (LSQ) vec_st( LSQ, 15, (unsigned char *)to ); // Store the LSQ part first + vec_st( MSQ, 0, (unsigned char *)to ); // Store the MSQ part second #else vec_xst(from, 0, to); #endif @@ -1209,6 +1210,16 @@ EIGEN_STRONG_INLINE Packet4f Bf16ToF32Odd(const Packet8bf& bf){ ); } +// Simple interleaving of bool masks, prevents true values from being +// converted to NaNs. +EIGEN_STRONG_INLINE Packet8bf F32ToBf16Bool(Packet4f even, Packet4f odd) { + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(high_mask, 0xFFFF0000); + Packet4f bf_odd, bf_even; + bf_odd = pand(reinterpret_cast<Packet4f>(p4ui_high_mask), odd); + bf_even = plogical_shift_right<16>(even); + return reinterpret_cast<Packet8us>(por<Packet4f>(bf_even, bf_odd)); +} + EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f p4f){ Packet4ui input = reinterpret_cast<Packet4ui>(p4f); Packet4ui lsb = plogical_shift_right<16>(input); @@ -1272,6 +1283,15 @@ EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f even, Packet4f odd){ Packet4f op_odd = OP(a_odd, b_odd);\ return F32ToBf16(op_even, op_odd);\ +#define BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(OP, A, B) \ + Packet4f a_even = Bf16ToF32Even(A);\ + Packet4f a_odd = Bf16ToF32Odd(A);\ + Packet4f b_even = Bf16ToF32Even(B);\ + Packet4f b_odd = Bf16ToF32Odd(B);\ + Packet4f op_even = OP(a_even, b_even);\ + Packet4f op_odd = OP(a_odd, b_odd);\ + return F32ToBf16Bool(op_even, op_odd);\ + template<> EIGEN_STRONG_INLINE Packet8bf padd<Packet8bf>(const Packet8bf& a, const Packet8bf& b) { BF16_TO_F32_BINARY_OP_WRAPPER(padd<Packet4f>, a, b); } @@ -1301,12 +1321,28 @@ template<> EIGEN_STRONG_INLINE Packet8bf prsqrt<Packet8bf> (const Packet8bf& a){ template<> EIGEN_STRONG_INLINE Packet8bf pexp<Packet8bf> (const Packet8bf& a){ BF16_TO_F32_UNARY_OP_WRAPPER(pexp_float, a); } + +template<> EIGEN_STRONG_INLINE Packet4f pldexp<Packet4f>(const Packet4f& a, const Packet4f& exponent) { + return pldexp_float(a,exponent); +} template<> EIGEN_STRONG_INLINE Packet8bf pldexp<Packet8bf> (const Packet8bf& a, const Packet8bf& exponent){ BF16_TO_F32_BINARY_OP_WRAPPER(pldexp_float, a, exponent); } -template<> EIGEN_STRONG_INLINE Packet8bf pfrexp<Packet8bf> (const Packet8bf& a, Packet8bf& exponent){ - BF16_TO_F32_BINARY_OP_WRAPPER(pfrexp_float, a, exponent); + +template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent) { + return pfrexp_float(a,exponent); +} +template<> EIGEN_STRONG_INLINE Packet8bf pfrexp<Packet8bf> (const Packet8bf& a, Packet8bf& e){ + Packet4f a_even = Bf16ToF32Even(a); + Packet4f a_odd = Bf16ToF32Odd(a); + Packet4f e_even; + Packet4f e_odd; + Packet4f op_even = pfrexp<Packet4f>(a_even, e_even); + Packet4f op_odd = pfrexp<Packet4f>(a_odd, e_odd); + e = F32ToBf16(e_even, e_odd); + return F32ToBf16(op_even, op_odd); } + template<> EIGEN_STRONG_INLINE Packet8bf psin<Packet8bf> (const Packet8bf& a){ BF16_TO_F32_UNARY_OP_WRAPPER(psin_float, a); } @@ -1346,13 +1382,13 @@ template<> EIGEN_STRONG_INLINE Packet8bf pmax<Packet8bf>(const Packet8bf& a, con } template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt(const Packet8bf& a, const Packet8bf& b) { - BF16_TO_F32_BINARY_OP_WRAPPER(pcmp_lt<Packet4f>, a, b); + BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(pcmp_lt<Packet4f>, a, b); } template<> EIGEN_STRONG_INLINE Packet8bf pcmp_le(const Packet8bf& a, const Packet8bf& b) { - BF16_TO_F32_BINARY_OP_WRAPPER(pcmp_le<Packet4f>, a, b); + BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(pcmp_le<Packet4f>, a, b); } template<> EIGEN_STRONG_INLINE Packet8bf pcmp_eq(const Packet8bf& a, const Packet8bf& b) { - BF16_TO_F32_BINARY_OP_WRAPPER(pcmp_eq<Packet4f>, a, b); + BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(pcmp_eq<Packet4f>, a, b); } template<> EIGEN_STRONG_INLINE bfloat16 pfirst(const Packet8bf& a) { @@ -1370,14 +1406,6 @@ template<> EIGEN_STRONG_INLINE Packet8bf plset<Packet8bf>(const bfloat16& a) { return padd<Packet8bf>(pset1<Packet8bf>(a), pload<Packet8bf>(countdown)); } -template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent) { - return pfrexp_float(a,exponent); -} - -template<> EIGEN_STRONG_INLINE Packet4f pldexp<Packet4f>(const Packet4f& a, const Packet4f& exponent) { - return pldexp_float(a,exponent); -} - template<> EIGEN_STRONG_INLINE float predux<Packet4f>(const Packet4f& a) { Packet4f b, sum; |