aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/AltiVec
diff options
context:
space:
mode:
authorGravatar Antonio Sanchez <cantonios@google.com>2021-01-26 10:23:23 -0800
committerGravatar Antonio Sánchez <cantonios@google.com>2021-01-28 18:37:09 +0000
commit1615a2799384a2964d01ba77fe98e3f6fcc412f4 (patch)
tree1901a8fc4ac9a668138979d015d3d9265eb63163 /Eigen/src/Core/arch/AltiVec
parent1414e2212c3cd36e2653bca0e11c653ece8f4d04 (diff)
Fix altivec packetmath.
Allows the altivec packetmath tests to pass. There were a few issues: - `pstoreu` was missing MSQ on `_BIG_ENDIAN` systems - `cmp_*` didn't properly handle conversion of bool flags (0x7FC instead of 0xFFFF) - `pfrexp` needed to set the `exponent` argument. Related to !370, #2128 cc: @ChipKerchner @pdrocaldeira Tested on `_BIG_ENDIAN` running on QEMU with VSX. Couldn't figure out build flags to get it to work for little endian.
Diffstat (limited to 'Eigen/src/Core/arch/AltiVec')
-rwxr-xr-xEigen/src/Core/arch/AltiVec/PacketMath.h54
1 files changed, 41 insertions, 13 deletions
diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h
index 6d7842021..fdf4f1e9c 100755
--- a/Eigen/src/Core/arch/AltiVec/PacketMath.h
+++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h
@@ -1056,6 +1056,7 @@ template<typename Packet> EIGEN_STRONG_INLINE void pstoreu_common(__UNPACK_TYPE_
MSQ = vec_perm(edges,(Packet16uc)from,align); // misalign the data (MSQ)
LSQ = vec_perm((Packet16uc)from,edges,align); // misalign the data (LSQ)
vec_st( LSQ, 15, (unsigned char *)to ); // Store the LSQ part first
+ vec_st( MSQ, 0, (unsigned char *)to ); // Store the MSQ part second
#else
vec_xst(from, 0, to);
#endif
@@ -1209,6 +1210,16 @@ EIGEN_STRONG_INLINE Packet4f Bf16ToF32Odd(const Packet8bf& bf){
);
}
+// Simple interleaving of bool masks, prevents true values from being
+// converted to NaNs.
+EIGEN_STRONG_INLINE Packet8bf F32ToBf16Bool(Packet4f even, Packet4f odd) {
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(high_mask, 0xFFFF0000);
+ Packet4f bf_odd, bf_even;
+ bf_odd = pand(reinterpret_cast<Packet4f>(p4ui_high_mask), odd);
+ bf_even = plogical_shift_right<16>(even);
+ return reinterpret_cast<Packet8us>(por<Packet4f>(bf_even, bf_odd));
+}
+
EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f p4f){
Packet4ui input = reinterpret_cast<Packet4ui>(p4f);
Packet4ui lsb = plogical_shift_right<16>(input);
@@ -1272,6 +1283,15 @@ EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f even, Packet4f odd){
Packet4f op_odd = OP(a_odd, b_odd);\
return F32ToBf16(op_even, op_odd);\
+#define BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(OP, A, B) \
+ Packet4f a_even = Bf16ToF32Even(A);\
+ Packet4f a_odd = Bf16ToF32Odd(A);\
+ Packet4f b_even = Bf16ToF32Even(B);\
+ Packet4f b_odd = Bf16ToF32Odd(B);\
+ Packet4f op_even = OP(a_even, b_even);\
+ Packet4f op_odd = OP(a_odd, b_odd);\
+ return F32ToBf16Bool(op_even, op_odd);\
+
template<> EIGEN_STRONG_INLINE Packet8bf padd<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
BF16_TO_F32_BINARY_OP_WRAPPER(padd<Packet4f>, a, b);
}
@@ -1301,12 +1321,28 @@ template<> EIGEN_STRONG_INLINE Packet8bf prsqrt<Packet8bf> (const Packet8bf& a){
template<> EIGEN_STRONG_INLINE Packet8bf pexp<Packet8bf> (const Packet8bf& a){
BF16_TO_F32_UNARY_OP_WRAPPER(pexp_float, a);
}
+
+template<> EIGEN_STRONG_INLINE Packet4f pldexp<Packet4f>(const Packet4f& a, const Packet4f& exponent) {
+ return pldexp_float(a,exponent);
+}
template<> EIGEN_STRONG_INLINE Packet8bf pldexp<Packet8bf> (const Packet8bf& a, const Packet8bf& exponent){
BF16_TO_F32_BINARY_OP_WRAPPER(pldexp_float, a, exponent);
}
-template<> EIGEN_STRONG_INLINE Packet8bf pfrexp<Packet8bf> (const Packet8bf& a, Packet8bf& exponent){
- BF16_TO_F32_BINARY_OP_WRAPPER(pfrexp_float, a, exponent);
+
+template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent) {
+ return pfrexp_float(a,exponent);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pfrexp<Packet8bf> (const Packet8bf& a, Packet8bf& e){
+ Packet4f a_even = Bf16ToF32Even(a);
+ Packet4f a_odd = Bf16ToF32Odd(a);
+ Packet4f e_even;
+ Packet4f e_odd;
+ Packet4f op_even = pfrexp<Packet4f>(a_even, e_even);
+ Packet4f op_odd = pfrexp<Packet4f>(a_odd, e_odd);
+ e = F32ToBf16(e_even, e_odd);
+ return F32ToBf16(op_even, op_odd);
}
+
template<> EIGEN_STRONG_INLINE Packet8bf psin<Packet8bf> (const Packet8bf& a){
BF16_TO_F32_UNARY_OP_WRAPPER(psin_float, a);
}
@@ -1346,13 +1382,13 @@ template<> EIGEN_STRONG_INLINE Packet8bf pmax<Packet8bf>(const Packet8bf& a, con
}
template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt(const Packet8bf& a, const Packet8bf& b) {
- BF16_TO_F32_BINARY_OP_WRAPPER(pcmp_lt<Packet4f>, a, b);
+ BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(pcmp_lt<Packet4f>, a, b);
}
template<> EIGEN_STRONG_INLINE Packet8bf pcmp_le(const Packet8bf& a, const Packet8bf& b) {
- BF16_TO_F32_BINARY_OP_WRAPPER(pcmp_le<Packet4f>, a, b);
+ BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(pcmp_le<Packet4f>, a, b);
}
template<> EIGEN_STRONG_INLINE Packet8bf pcmp_eq(const Packet8bf& a, const Packet8bf& b) {
- BF16_TO_F32_BINARY_OP_WRAPPER(pcmp_eq<Packet4f>, a, b);
+ BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(pcmp_eq<Packet4f>, a, b);
}
template<> EIGEN_STRONG_INLINE bfloat16 pfirst(const Packet8bf& a) {
@@ -1370,14 +1406,6 @@ template<> EIGEN_STRONG_INLINE Packet8bf plset<Packet8bf>(const bfloat16& a) {
return padd<Packet8bf>(pset1<Packet8bf>(a), pload<Packet8bf>(countdown));
}
-template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent) {
- return pfrexp_float(a,exponent);
-}
-
-template<> EIGEN_STRONG_INLINE Packet4f pldexp<Packet4f>(const Packet4f& a, const Packet4f& exponent) {
- return pldexp_float(a,exponent);
-}
-
template<> EIGEN_STRONG_INLINE float predux<Packet4f>(const Packet4f& a)
{
Packet4f b, sum;