diff options
-rwxr-xr-x | Eigen/src/Core/arch/AltiVec/PacketMath.h | 431 | ||||
-rw-r--r-- | test/packetmath.cpp | 31 |
2 files changed, 457 insertions, 5 deletions
diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h index e0da7c329..09ad0a74e 100755 --- a/Eigen/src/Core/arch/AltiVec/PacketMath.h +++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h @@ -39,10 +39,10 @@ typedef __vector short int Packet8s; typedef __vector unsigned short int Packet8us; typedef __vector int8_t Packet16c; typedef __vector uint8_t Packet16uc; +typedef eigen_packet_wrapper<__vector unsigned short int,0> Packet8bf; // We don't want to write the same code all the time, but we need to reuse the constants // and it doesn't really work to declare them global, so we define macros instead - #define _EIGEN_DECLARE_CONST_FAST_Packet4f(NAME,X) \ Packet4f p4f_##NAME = {X, X, X, X} @@ -96,6 +96,7 @@ static Packet4f p4f_COUNTDOWN = { 0.0, 1.0, 2.0, 3.0 }; static Packet4i p4i_COUNTDOWN = { 0, 1, 2, 3 }; static Packet8s p8s_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7 }; static Packet8us p8us_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7 }; + static Packet16c p16c_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; static Packet16uc p16uc_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7, @@ -108,6 +109,8 @@ static Packet16uc p16uc_REVERSE8 = { 15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0 }; static Packet16uc p16uc_DUPLICATE32_HI = { 0,1,2,3, 0,1,2,3, 4,5,6,7, 4,5,6,7 }; static Packet16uc p16uc_DUPLICATE16_HI = { 0,1,0,1, 2,3,2,3, 4,5,4,5, 6,7,6,7 }; static Packet16uc p16uc_DUPLICATE8_HI = { 0,0, 1,1, 2,2, 3,3, 4,4, 5,5, 6,6, 7,7 }; +static const Packet16uc p16uc_DUPLICATE16_EVEN= { 0,1 ,0,1, 4,5, 4,5, 8,9, 8,9, 12,13, 12,13 }; +static const Packet16uc p16uc_DUPLICATE16_ODD = { 2,3 ,2,3, 6,7, 6,7, 10,11, 10,11, 14,15, 14,15 }; static Packet16uc p16uc_QUADRUPLICATE16_HI = { 0,1,0,1,0,1,0,1, 2,3,2,3,2,3,2,3 }; @@ -190,6 +193,48 @@ struct packet_traits<float> : default_packet_traits { }; }; template <> +struct packet_traits<bfloat16> : default_packet_traits { + typedef Packet8bf type; + typedef Packet8bf half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 8, + HasHalfPacket = 0, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasMin = 1, + HasMax = 1, + HasAbs = 1, + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, + HasLog = 1, + HasExp = 1, +#ifdef __VSX__ + HasSqrt = 1, +#if !EIGEN_COMP_CLANG + HasRsqrt = 1, +#else + HasRsqrt = 0, +#endif +#else + HasSqrt = 0, + HasRsqrt = 0, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, +#endif + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasNegate = 1, + HasBlend = 1 + }; +}; + +template <> struct packet_traits<int> : default_packet_traits { typedef Packet4i type; typedef Packet4i half; @@ -319,6 +364,12 @@ template<> struct unpacket_traits<Packet16uc> enum {size=16, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; }; +template<> struct unpacket_traits<Packet8bf> +{ + typedef bfloat16 type; + typedef Packet8bf half; + enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; +}; inline std::ostream & operator <<(std::ostream & s, const Packet16c & v) { union { @@ -421,6 +472,11 @@ template<> EIGEN_STRONG_INLINE Packet16uc pload<Packet16uc>(const uint8_t* f return pload_common<Packet16uc>(from); } +template<> EIGEN_STRONG_INLINE Packet8bf pload<Packet8bf>(const bfloat16* from) +{ + return pload_common<Packet8us>(reinterpret_cast<const unsigned short int*>(from)); +} + template <typename Packet> EIGEN_STRONG_INLINE void pstore_common(__UNPACK_TYPE__(Packet)* to, const Packet& from){ // some versions of GCC throw "unused-but-set-parameter" (float *to). @@ -431,7 +487,7 @@ EIGEN_STRONG_INLINE void pstore_common(__UNPACK_TYPE__(Packet)* to, const Packet vec_xst(from, 0, to); #else vec_st(from, 0, to); -#endif +#endif } template<> EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet4f& from) @@ -454,6 +510,11 @@ template<> EIGEN_STRONG_INLINE void pstore<unsigned short int>(unsigned short in pstore_common<Packet8us>(to, from); } +template<> EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet8bf& from) +{ + pstore_common<Packet8us>(reinterpret_cast<unsigned short int*>(to), from); +} + template<> EIGEN_STRONG_INLINE void pstore<int8_t>(int8_t* to, const Packet16c& from) { pstore_common<Packet16c>(to, from); @@ -513,6 +574,10 @@ template<> EIGEN_STRONG_INLINE Packet4f pset1frombits<Packet4f>(unsigned int fro return reinterpret_cast<Packet4f>(pset1<Packet4i>(from)); } +template<> EIGEN_STRONG_INLINE Packet8bf pset1<Packet8bf>(const bfloat16& from) { + return pset1_size8<Packet8us>(reinterpret_cast<const unsigned short int&>(from)); +} + template<typename Packet> EIGEN_STRONG_INLINE void pbroadcast4_common(const __UNPACK_TYPE__(Packet) *a, Packet& a0, Packet& a1, Packet& a2, Packet& a3) @@ -700,6 +765,7 @@ template<> EIGEN_STRONG_INLINE Packet16uc plset<Packet16uc>(const uint8_t& a) template<> EIGEN_STRONG_INLINE Packet4f padd<Packet4f> (const Packet4f& a, const Packet4f& b) { return a + b; } template<> EIGEN_STRONG_INLINE Packet4i padd<Packet4i> (const Packet4i& a, const Packet4i& b) { return a + b; } +template<> EIGEN_STRONG_INLINE Packet4ui padd<Packet4ui> (const Packet4ui& a, const Packet4ui& b) { return a + b; } template<> EIGEN_STRONG_INLINE Packet8s padd<Packet8s> (const Packet8s& a, const Packet8s& b) { return a + b; } template<> EIGEN_STRONG_INLINE Packet8us padd<Packet8us> (const Packet8us& a, const Packet8us& b) { return a + b; } template<> EIGEN_STRONG_INLINE Packet16c padd<Packet16c> (const Packet16c& a, const Packet16c& b) { return a + b; } @@ -721,6 +787,7 @@ template<> EIGEN_STRONG_INLINE Packet4i pmul<Packet4i> (const Packet4i& a, template<> EIGEN_STRONG_INLINE Packet16c pmul<Packet16c> (const Packet16c& a, const Packet16c& b) { return vec_mul(a,b); } template<> EIGEN_STRONG_INLINE Packet16uc pmul<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vec_mul(a,b); } + template<> EIGEN_STRONG_INLINE Packet4f pdiv<Packet4f>(const Packet4f& a, const Packet4f& b) { #ifndef __VSX__ // VSX actually provides a div instruction @@ -765,6 +832,7 @@ template<> EIGEN_STRONG_INLINE Packet8us pmin<Packet8us>(const Packet8us& a, con template<> EIGEN_STRONG_INLINE Packet16c pmin<Packet16c>(const Packet16c& a, const Packet16c& b) { return vec_min(a, b); } template<> EIGEN_STRONG_INLINE Packet16uc pmin<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vec_min(a, b); } + template<> EIGEN_STRONG_INLINE Packet4f pmax<Packet4f>(const Packet4f& a, const Packet4f& b) { #ifdef __VSX__ @@ -785,6 +853,7 @@ template<> EIGEN_STRONG_INLINE Packet16uc pmax<Packet16uc>(const Packet16uc& a, template<> EIGEN_STRONG_INLINE Packet4f pcmp_le(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmple(a,b)); } template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmplt(a,b)); } template<> EIGEN_STRONG_INLINE Packet4f pcmp_eq(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmpeq(a,b)); } + template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt_or_nan(const Packet4f& a, const Packet4f& b) { Packet4f c = reinterpret_cast<Packet4f>(vec_cmpge(a,b)); return vec_nor(c,c); @@ -793,12 +862,26 @@ template<> EIGEN_STRONG_INLINE Packet4i pcmp_eq(const Packet4i& a, const Packet4 template<> EIGEN_STRONG_INLINE Packet4f pand<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_and(a, b); } template<> EIGEN_STRONG_INLINE Packet4i pand<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_and(a, b); } +template<> EIGEN_STRONG_INLINE Packet4ui pand<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return vec_and(a, b); } +template<> EIGEN_STRONG_INLINE Packet8us pand<Packet8us>(const Packet8us& a, const Packet8us& b) { return vec_and(a, b); } +template<> EIGEN_STRONG_INLINE Packet8bf pand<Packet8bf>(const Packet8bf& a, const Packet8bf& b) { + return pand<Packet8us>(a, b); +} + template<> EIGEN_STRONG_INLINE Packet4f por<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_or(a, b); } template<> EIGEN_STRONG_INLINE Packet4i por<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_or(a, b); } +template<> EIGEN_STRONG_INLINE Packet8s por<Packet8s>(const Packet8s& a, const Packet8s& b) { return vec_or(a, b); } +template<> EIGEN_STRONG_INLINE Packet8us por<Packet8us>(const Packet8us& a, const Packet8us& b) { return vec_or(a, b); } +template<> EIGEN_STRONG_INLINE Packet8bf por<Packet8bf>(const Packet8bf& a, const Packet8bf& b) { + return por<Packet8us>(a, b); +} template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_xor(a, b); } template<> EIGEN_STRONG_INLINE Packet4i pxor<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_xor(a, b); } +template<> EIGEN_STRONG_INLINE Packet8bf pxor<Packet8bf>(const Packet8bf& a, const Packet8bf& b) { + return pxor<Packet8us>(a, b); +} template<> EIGEN_STRONG_INLINE Packet4f pandnot<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_and(a, vec_nor(b, b)); } template<> EIGEN_STRONG_INLINE Packet4i pandnot<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_and(a, vec_nor(b, b)); } @@ -806,6 +889,7 @@ template<> EIGEN_STRONG_INLINE Packet4i pandnot<Packet4i>(const Packet4i& a, con template<> EIGEN_STRONG_INLINE Packet4f pselect(const Packet4f& mask, const Packet4f& a, const Packet4f& b) { return vec_sel(b, a, reinterpret_cast<Packet4ui>(mask)); } + template<> EIGEN_STRONG_INLINE Packet4f pround<Packet4f>(const Packet4f& a) { Packet4f t = vec_add(reinterpret_cast<Packet4f>(vec_or(vec_and(reinterpret_cast<Packet4ui>(a), p4ui_SIGN), p4ui_PREV0DOT5)), a); Packet4f res; @@ -852,6 +936,10 @@ template<> EIGEN_STRONG_INLINE Packet8us ploadu<Packet8us>(const unsigned short { return ploadu_common<Packet8us>(from); } +template<> EIGEN_STRONG_INLINE Packet8bf ploadu<Packet8bf>(const bfloat16* from) +{ + return ploadu_common<Packet8us>(reinterpret_cast<const unsigned short int*>(from)); +} template<> EIGEN_STRONG_INLINE Packet16c ploadu<Packet16c>(const int8_t* from) { return ploadu_common<Packet16c>(from); @@ -909,6 +997,11 @@ template<> EIGEN_STRONG_INLINE Packet8us ploadquad<Packet8us>(const unsigned sho return vec_perm(p, p, p16uc_QUADRUPLICATE16_HI); } +template<> EIGEN_STRONG_INLINE Packet8bf ploadquad<Packet8bf>(const bfloat16* from) +{ + return ploadquad<Packet8us>(reinterpret_cast<const unsigned short int*>(from)); +} + template<> EIGEN_STRONG_INLINE Packet16c ploaddup<Packet16c>(const int8_t* from) { Packet16c p; @@ -962,6 +1055,10 @@ template<> EIGEN_STRONG_INLINE void pstoreu<unsigned short int>(unsigned short i { pstoreu_common<Packet8us>(to, from); } +template<> EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet8bf& from) +{ + pstoreu_common<Packet8us>(reinterpret_cast<unsigned short int*>(to), from); +} template<> EIGEN_STRONG_INLINE void pstoreu<int8_t>(int8_t* to, const Packet16c& from) { pstoreu_common<Packet16c>(to, from); @@ -977,17 +1074,17 @@ template<> EIGEN_STRONG_INLINE void prefetch<int>(const int* addr) { EIGE template<> EIGEN_STRONG_INLINE float pfirst<Packet4f>(const Packet4f& a) { EIGEN_ALIGN16 float x; vec_ste(a, 0, &x); return x; } template<> EIGEN_STRONG_INLINE int pfirst<Packet4i>(const Packet4i& a) { EIGEN_ALIGN16 int x; vec_ste(a, 0, &x); return x; } -template<typename Packet> EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) pfirst_common(const Packet& a) { +template<typename Packet> EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) pfirst_common(const Packet& a) { EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) x; vec_ste(a, 0, &x); return x; } -template<> EIGEN_STRONG_INLINE short int pfirst<Packet8s>(const Packet8s& a) { +template<> EIGEN_STRONG_INLINE short int pfirst<Packet8s>(const Packet8s& a) { return pfirst_common<Packet8s>(a); } -template<> EIGEN_STRONG_INLINE unsigned short int pfirst<Packet8us>(const Packet8us& a) { +template<> EIGEN_STRONG_INLINE unsigned short int pfirst<Packet8us>(const Packet8us& a) { return pfirst_common<Packet8us>(a); } @@ -1025,6 +1122,10 @@ template<> EIGEN_STRONG_INLINE Packet16uc preverse(const Packet16uc& a) { return vec_perm(a, a, p16uc_REVERSE8); } +template<> EIGEN_STRONG_INLINE Packet8bf preverse(const Packet8bf& a) +{ + return preverse<Packet8us>(a); +} template<> EIGEN_STRONG_INLINE Packet4f pabs(const Packet4f& a) { return vec_abs(a); } template<> EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a) { return vec_abs(a); } @@ -1032,6 +1133,10 @@ template<> EIGEN_STRONG_INLINE Packet8s pabs(const Packet8s& a) { return vec_abs template<> EIGEN_STRONG_INLINE Packet8us pabs(const Packet8us& a) { return a; } template<> EIGEN_STRONG_INLINE Packet16c pabs(const Packet16c& a) { return vec_abs(a); } template<> EIGEN_STRONG_INLINE Packet16uc pabs(const Packet16uc& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) { + _EIGEN_DECLARE_CONST_FAST_Packet8us(abs_mask,0x7FFF); + return pand<Packet8us>(p8us_abs_mask, a); +} template<int N> EIGEN_STRONG_INLINE Packet4i parithmetic_shift_right(Packet4i a) { return vec_sra(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); } @@ -1039,6 +1144,175 @@ template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_right(Packet4i a) { return vec_sr(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); } template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_left(Packet4i a) { return vec_sl(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); } +template<int N> EIGEN_STRONG_INLINE Packet4f plogical_shift_left(Packet4f a) +{ + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N); + Packet4ui r = vec_sl(reinterpret_cast<Packet4ui>(a), p4ui_mask); + return reinterpret_cast<Packet4f>(r); +} + +template<int N> EIGEN_STRONG_INLINE Packet4f plogical_shift_right(Packet4f a) +{ + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N); + Packet4ui r = vec_sr(reinterpret_cast<Packet4ui>(a), p4ui_mask); + return reinterpret_cast<Packet4f>(r); +} + +template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_right(Packet4ui a) +{ + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N); + return vec_sr(a, p4ui_mask); +} + +template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_left(Packet4ui a) +{ + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N); + return vec_sl(a, p4ui_mask); +} + +template<int N> EIGEN_STRONG_INLINE Packet8us plogical_shift_left(Packet8us a) +{ + const _EIGEN_DECLARE_CONST_FAST_Packet8us(mask, N); + return vec_sl(a, p8us_mask); +} + +EIGEN_STRONG_INLINE Packet4f Bf16ToF32Even(const Packet8bf& bf){ + return plogical_shift_left<16>(reinterpret_cast<Packet4f>(bf.m_val)); +} + +EIGEN_STRONG_INLINE Packet4f Bf16ToF32Odd(const Packet8bf& bf){ + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(high_mask, 0xFFFF0000); + return pand<Packet4f>( + reinterpret_cast<Packet4f>(bf.m_val), + reinterpret_cast<Packet4f>(p4ui_high_mask) + ); +} + +EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f p4f){ + Packet4ui input = reinterpret_cast<Packet4ui>(p4f); + Packet4ui lsb = plogical_shift_right<16>(input); + lsb = pand<Packet4ui>(lsb, reinterpret_cast<Packet4ui>(p4i_ONE)); + + _EIGEN_DECLARE_CONST_FAST_Packet4ui(BIAS,0x7FFFu); + Packet4ui rounding_bias = padd<Packet4ui>(lsb, p4ui_BIAS); + input = padd<Packet4ui>(input, rounding_bias); + + //Test NaN and Subnormal - Begin + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(exp_mask, 0x7F800000); + Packet4ui exp = pand<Packet4ui>(p4ui_exp_mask, reinterpret_cast<Packet4ui>(p4f)); + + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mantissa_mask, 0x7FFFFF); + Packet4ui mantissa = pand<Packet4ui>(p4ui_mantissa_mask, reinterpret_cast<Packet4ui>(p4f)); + + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(max_exp, 0x7F800000); + Packet4bi is_max_exp = vec_cmpeq(exp, p4ui_max_exp); + Packet4bi is_zero_exp = vec_cmpeq(exp, reinterpret_cast<Packet4ui>(p4i_ZERO)); + + Packet4bi is_mant_not_zero = vec_cmpne(mantissa, reinterpret_cast<Packet4ui>(p4i_ZERO)); + Packet4ui nan_selector = pand<Packet4ui>( + reinterpret_cast<Packet4ui>(is_max_exp), + reinterpret_cast<Packet4ui>(is_mant_not_zero) + ); + + Packet4ui subnormal_selector = pand<Packet4ui>( + reinterpret_cast<Packet4ui>(is_zero_exp), + reinterpret_cast<Packet4ui>(is_mant_not_zero) + ); + + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(nan, 0x7FC00000); + input = vec_sel(input, p4ui_nan, nan_selector); + input = vec_sel(input, reinterpret_cast<Packet4ui>(p4f), subnormal_selector); + //Test NaN and Subnormal - End + + input = plogical_shift_right<16>(input); + return reinterpret_cast<Packet8us>(input); +} + +EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f even, Packet4f odd){ + Packet4f bf_odd, bf_even; + bf_odd = reinterpret_cast<Packet4f>(F32ToBf16(odd).m_val); + bf_odd = plogical_shift_left<16>(bf_odd); + bf_even = reinterpret_cast<Packet4f>(F32ToBf16(even).m_val); + return reinterpret_cast<Packet8us>(por<Packet4f>(bf_even, bf_odd)); +} +#define BF16_TO_F32_UNARY_OP_WRAPPER(OP, A) \ + Packet4f a_even = Bf16ToF32Even(A);\ + Packet4f a_odd = Bf16ToF32Odd(A);\ + Packet4f op_even = OP(a_even);\ + Packet4f op_odd = OP(a_odd);\ + return F32ToBf16(op_even, op_odd);\ + +#define BF16_TO_F32_BINARY_OP_WRAPPER(OP, A, B) \ + Packet4f a_even = Bf16ToF32Even(A);\ + Packet4f a_odd = Bf16ToF32Odd(A);\ + Packet4f b_even = Bf16ToF32Even(B);\ + Packet4f b_odd = Bf16ToF32Odd(B);\ + Packet4f op_even = OP(a_even, b_even);\ + Packet4f op_odd = OP(a_odd, b_odd);\ + return F32ToBf16(op_even, op_odd);\ + +template<> EIGEN_STRONG_INLINE Packet8bf padd<Packet8bf>(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(padd<Packet4f>, a, b); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pmul<Packet8bf>(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(pmul<Packet4f>, a, b); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pdiv<Packet8bf>(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(pdiv<Packet4f>, a, b); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pnegate<Packet8bf>(const Packet8bf& a) { + BF16_TO_F32_UNARY_OP_WRAPPER(pnegate<Packet4f>, a); +} + +template<> EIGEN_STRONG_INLINE Packet8bf psub<Packet8bf>(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(psub<Packet4f>, a, b); +} + +template<> EIGEN_STRONG_INLINE Packet8bf psqrt<Packet8bf> (const Packet8bf& a){ + BF16_TO_F32_UNARY_OP_WRAPPER(vec_sqrt, a); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pmadd(const Packet8bf& a, const Packet8bf& b, const Packet8bf& c) { + Packet4f a_even = Bf16ToF32Even(a); + Packet4f a_odd = Bf16ToF32Odd(a); + Packet4f b_even = Bf16ToF32Even(b); + Packet4f b_odd = Bf16ToF32Odd(b); + Packet4f c_even = Bf16ToF32Even(c); + Packet4f c_odd = Bf16ToF32Odd(c); + Packet4f pmadd_even = pmadd<Packet4f>(a_even, b_even, c_even); + Packet4f pmadd_odd = pmadd<Packet4f>(a_odd, b_odd, c_odd); + return F32ToBf16(pmadd_even, pmadd_odd); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pmin<Packet8bf>(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(pmin<Packet4f>, a, b); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pmax<Packet8bf>(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(pmax<Packet4f>, a, b); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(pcmp_lt<Packet4f>, a, b); +} +template<> EIGEN_STRONG_INLINE Packet8bf pcmp_le(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(pcmp_le<Packet4f>, a, b); +} +template<> EIGEN_STRONG_INLINE Packet8bf pcmp_eq(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(pcmp_eq<Packet4f>, a, b); +} + +template<> EIGEN_STRONG_INLINE bfloat16 pfirst(const Packet8bf& a) { + return Eigen::bfloat16_impl::raw_uint16_to_bfloat16((pfirst<Packet8us>(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf ploaddup<Packet8bf>(const bfloat16* from) +{ + return ploaddup<Packet8us>(reinterpret_cast<const unsigned short int*>(from)); +} template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent) { return pfrexp_float(a,exponent); @@ -1070,6 +1344,13 @@ template<> EIGEN_STRONG_INLINE int predux<Packet4i>(const Packet4i& a) return pfirst(sum); } +template<> EIGEN_STRONG_INLINE bfloat16 predux<Packet8bf>(const Packet8bf& a) +{ + float redux_even = predux<Packet4f>(Bf16ToF32Even(a)); + float redux_odd = predux<Packet4f>(Bf16ToF32Odd(a)); + float f32_result = redux_even + redux_odd; + return bfloat16(f32_result); +} template<typename Packet> EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) predux_size8(const Packet& a) { union{ @@ -1166,6 +1447,15 @@ template<> EIGEN_STRONG_INLINE unsigned short int predux_mul<Packet8us>(const Pa return pfirst(octo); } +template<> EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet8bf>(const Packet8bf& a) +{ + float redux_even = predux_mul<Packet4f>(Bf16ToF32Even(a)); + float redux_odd = predux_mul<Packet4f>(Bf16ToF32Odd(a)); + float f32_result = redux_even * redux_odd; + return bfloat16(f32_result); +} + + template<> EIGEN_STRONG_INLINE int8_t predux_mul<Packet16c>(const Packet16c& a) { Packet16c pair, quad, octo, result; @@ -1211,6 +1501,14 @@ template<> EIGEN_STRONG_INLINE int predux_min<Packet4i>(const Packet4i& a) return predux_min4<Packet4i>(a); } +template<> EIGEN_STRONG_INLINE bfloat16 predux_min<Packet8bf>(const Packet8bf& a) +{ + float redux_even = predux_min<Packet4f>(Bf16ToF32Even(a)); + float redux_odd = predux_min<Packet4f>(Bf16ToF32Odd(a)); + float f32_result = (std::min)(redux_even, redux_odd); + return bfloat16(f32_result); +} + template<> EIGEN_STRONG_INLINE short int predux_min<Packet8s>(const Packet8s& a) { Packet8s pair, quad, octo; @@ -1283,6 +1581,14 @@ template<> EIGEN_STRONG_INLINE int predux_max<Packet4i>(const Packet4i& a) return predux_max4<Packet4i>(a); } +template<> EIGEN_STRONG_INLINE bfloat16 predux_max<Packet8bf>(const Packet8bf& a) +{ + float redux_even = predux_max<Packet4f>(Bf16ToF32Even(a)); + float redux_odd = predux_max<Packet4f>(Bf16ToF32Odd(a)); + float f32_result = (std::max)(redux_even, redux_odd); + return bfloat16(f32_result); +} + template<> EIGEN_STRONG_INLINE short int predux_max<Packet8s>(const Packet8s& a) { Packet8s pair, quad, octo; @@ -1391,6 +1697,21 @@ ptranspose(PacketBlock<Packet8us,4>& kernel) { kernel.packet[3] = vec_mergel(t1, t3); } + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock<Packet8bf,4>& kernel) { + Packet8us t0, t1, t2, t3; + + t0 = vec_mergeh(kernel.packet[0].m_val, kernel.packet[2].m_val); + t1 = vec_mergel(kernel.packet[0].m_val, kernel.packet[2].m_val); + t2 = vec_mergeh(kernel.packet[1].m_val, kernel.packet[3].m_val); + t3 = vec_mergel(kernel.packet[1].m_val, kernel.packet[3].m_val); + kernel.packet[0] = vec_mergeh(t0, t2); + kernel.packet[1] = vec_mergel(t0, t2); + kernel.packet[2] = vec_mergeh(t1, t3); + kernel.packet[3] = vec_mergel(t1, t3); +} + EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16c,4>& kernel) { Packet16c t0, t1, t2, t3; @@ -1481,6 +1802,37 @@ ptranspose(PacketBlock<Packet8us,8>& kernel) { } EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock<Packet8bf,8>& kernel) { + Packet8bf v[8], sum[8]; + + v[0] = vec_mergeh(kernel.packet[0].m_val, kernel.packet[4].m_val); + v[1] = vec_mergel(kernel.packet[0].m_val, kernel.packet[4].m_val); + v[2] = vec_mergeh(kernel.packet[1].m_val, kernel.packet[5].m_val); + v[3] = vec_mergel(kernel.packet[1].m_val, kernel.packet[5].m_val); + v[4] = vec_mergeh(kernel.packet[2].m_val, kernel.packet[6].m_val); + v[5] = vec_mergel(kernel.packet[2].m_val, kernel.packet[6].m_val); + v[6] = vec_mergeh(kernel.packet[3].m_val, kernel.packet[7].m_val); + v[7] = vec_mergel(kernel.packet[3].m_val, kernel.packet[7].m_val); + sum[0] = vec_mergeh(v[0].m_val, v[4].m_val); + sum[1] = vec_mergel(v[0].m_val, v[4].m_val); + sum[2] = vec_mergeh(v[1].m_val, v[5].m_val); + sum[3] = vec_mergel(v[1].m_val, v[5].m_val); + sum[4] = vec_mergeh(v[2].m_val, v[6].m_val); + sum[5] = vec_mergel(v[2].m_val, v[6].m_val); + sum[6] = vec_mergeh(v[3].m_val, v[7].m_val); + sum[7] = vec_mergel(v[3].m_val, v[7].m_val); + + kernel.packet[0] = vec_mergeh(sum[0].m_val, sum[4].m_val); + kernel.packet[1] = vec_mergel(sum[0].m_val, sum[4].m_val); + kernel.packet[2] = vec_mergeh(sum[1].m_val, sum[5].m_val); + kernel.packet[3] = vec_mergel(sum[1].m_val, sum[5].m_val); + kernel.packet[4] = vec_mergeh(sum[2].m_val, sum[6].m_val); + kernel.packet[5] = vec_mergel(sum[2].m_val, sum[6].m_val); + kernel.packet[6] = vec_mergeh(sum[3].m_val, sum[7].m_val); + kernel.packet[7] = vec_mergel(sum[3].m_val, sum[7].m_val); +} + +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16c,16>& kernel) { Packet16c step1[16], step2[16], step3[16]; @@ -1656,6 +2008,10 @@ template<> EIGEN_STRONG_INLINE Packet8us pblend(const Selector<8>& ifPacket, con return vec_sel(elsePacket, thenPacket, mask); } +template<> EIGEN_STRONG_INLINE Packet8bf pblend(const Selector<8>& ifPacket, const Packet8bf& thenPacket, const Packet8bf& elsePacket) { + return pblend<Packet8us>(ifPacket, thenPacket, elsePacket); +} + template<> EIGEN_STRONG_INLINE Packet16c pblend(const Selector<16>& ifPacket, const Packet16c& thenPacket, const Packet16c& elsePacket) { Packet16uc select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3], ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7], @@ -1694,15 +2050,78 @@ struct type_casting_traits<int, float> { }; }; +template <> +struct type_casting_traits<bfloat16, unsigned short int> { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + +template <> +struct type_casting_traits<unsigned short int, bfloat16> { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; template<> EIGEN_STRONG_INLINE Packet4i pcast<Packet4f, Packet4i>(const Packet4f& a) { return vec_cts(a,0); } +template<> EIGEN_STRONG_INLINE Packet4ui pcast<Packet4f, Packet4ui>(const Packet4f& a) { + return vec_ctu(a,0); +} + template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4i, Packet4f>(const Packet4i& a) { return vec_ctf(a,0); } +template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4ui, Packet4f>(const Packet4ui& a) { + return vec_ctf(a,0); +} + +template<> EIGEN_STRONG_INLINE Packet8us pcast<Packet8bf, Packet8us>(const Packet8bf& a) { + Packet4f float_even = Bf16ToF32Even(a); + Packet4f float_odd = Bf16ToF32Odd(a); + Packet4ui int_even = pcast<Packet4f, Packet4ui>(float_even); + Packet4ui int_odd = pcast<Packet4f, Packet4ui>(float_odd); + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(low_mask, 0x0000FFFF); + Packet4ui low_even = pand<Packet4ui>(int_even, p4ui_low_mask); + Packet4ui low_odd = pand<Packet4ui>(int_odd, p4ui_low_mask); + + //Check values that are bigger than USHRT_MAX (0xFFFF) + Packet4bi overflow_selector; + if(vec_any_gt(int_even, p4ui_low_mask)){ + overflow_selector = vec_cmpgt(int_even, p4ui_low_mask); + low_even = vec_sel(low_even, p4ui_low_mask, overflow_selector); + } + if(vec_any_gt(int_odd, p4ui_low_mask)){ + overflow_selector = vec_cmpgt(int_odd, p4ui_low_mask); + low_odd = vec_sel(low_even, p4ui_low_mask, overflow_selector); + } + + low_odd = plogical_shift_left<16>(low_odd); + + Packet4ui int_final = por<Packet4ui>(low_even, low_odd); + return reinterpret_cast<Packet8us>(int_final); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pcast<Packet8us, Packet8bf>(const Packet8us& a) { + //short -> int -> float -> bfloat16 + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(low_mask, 0x0000FFFF); + Packet4ui int_cast = reinterpret_cast<Packet4ui>(a); + Packet4ui int_even = pand<Packet4ui>(int_cast, p4ui_low_mask); + Packet4ui int_odd = plogical_shift_right<16>(int_cast); + Packet4f float_even = pcast<Packet4ui, Packet4f>(int_even); + Packet4f float_odd = pcast<Packet4ui, Packet4f>(int_odd); + return F32ToBf16(float_even, float_odd); +} + + template<> EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i,Packet4f>(const Packet4f& a) { return reinterpret_cast<Packet4i>(a); } @@ -2024,6 +2443,8 @@ template<> EIGEN_STRONG_INLINE Packet2d pblend(const Selector<2>& ifPacket, cons Packet2bl mask = reinterpret_cast<Packet2bl>( vec_cmpeq(reinterpret_cast<Packet2d>(select), reinterpret_cast<Packet2d>(p2l_ONE)) ); return vec_sel(elsePacket, thenPacket, mask); } + + #endif // __VSX__ } // end namespace internal diff --git a/test/packetmath.cpp b/test/packetmath.cpp index 0fe29102a..c8ea3139e 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -247,6 +247,20 @@ void packetmath_boolean_mask_ops() { data1[i + PacketSize] = internal::random<bool>() ? data1[i] : Scalar(0); } CHECK_CWISE2_IF(true, internal::pcmp_eq, internal::pcmp_eq); + + //Test (-0) == (0) for signed operations + for (int i = 0; i < PacketSize; ++i) { + data1[i] = Scalar(-0.0); + data1[i + PacketSize] = internal::random<bool>() ? data1[i] : Scalar(0); + } + CHECK_CWISE2_IF(true, internal::pcmp_eq, internal::pcmp_eq); + + //Test NaN + for (int i = 0; i < PacketSize; ++i) { + data1[i] = std::numeric_limits<Scalar>::quiet_NaN(); + data1[i + PacketSize] = internal::random<bool>() ? data1[i] : Scalar(0); + } + CHECK_CWISE2_IF(true, internal::pcmp_eq, internal::pcmp_eq); } // Packet16b representing bool does not support ptrue, pandnot or pcmp_eq, since the scalar path @@ -255,6 +269,22 @@ template<> void packetmath_boolean_mask_ops<bool, internal::packet_traits<bool>::type>() {} template <typename Scalar, typename Packet> +void packetmath_minus_zero_add() { + const int PacketSize = internal::unpacket_traits<Packet>::size; + const int size = 2 * PacketSize; + EIGEN_ALIGN_MAX Scalar data1[size]; + EIGEN_ALIGN_MAX Scalar data2[size]; + EIGEN_ALIGN_MAX Scalar ref[size]; + + for (int i = 0; i < PacketSize; ++i) { + data1[i] = Scalar(-0.0); + data1[i + PacketSize] = Scalar(-0.0); + } + CHECK_CWISE2_IF(internal::packet_traits<Scalar>::HasAdd, REF_ADD, internal::padd); +} + + +template <typename Scalar, typename Packet> void packetmath() { typedef internal::packet_traits<Scalar> PacketTraits; const int PacketSize = internal::unpacket_traits<Packet>::size; @@ -454,6 +484,7 @@ void packetmath() { packetmath_boolean_mask_ops<Scalar, Packet>(); packetmath_pcast_ops_runner<Scalar, Packet>::run(); + packetmath_minus_zero_add<Scalar, Packet>(); } template <typename Scalar, typename Packet> |