diff options
author | David Tellenbach <david.tellenbach@me.com> | 2020-08-13 15:48:40 +0000 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2020-08-13 15:48:40 +0000 |
commit | 8ba1b0f41a7950dc3e1d4ed75859e36c73311235 (patch) | |
tree | 662ac7a45607d0527c925f22c78f0739aa9823d9 /Eigen/src/Core/arch/NEON/PacketMath.h | |
parent | 704798d1df4866be335ca013da19a44791f85a7e (diff) |
bfloat16 packetmath for Arm Neon backend
Diffstat (limited to 'Eigen/src/Core/arch/NEON/PacketMath.h')
-rw-r--r-- | Eigen/src/Core/arch/NEON/PacketMath.h | 263 |
1 files changed, 263 insertions, 0 deletions
diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h index c2fdcbade..26d1d1298 100644 --- a/Eigen/src/Core/arch/NEON/PacketMath.h +++ b/Eigen/src/Core/arch/NEON/PacketMath.h @@ -3218,6 +3218,269 @@ template<> EIGEN_STRONG_INLINE Packet4ui psqrt(const Packet4ui& a) { return res; } +//---------- bfloat16 ---------- +// TODO: Add support for native armv8.6-a bfloat16_t + +// TODO: Guard if we have native bfloat16 support +typedef eigen_packet_wrapper<uint16x4_t, 19> Packet4bf; + +template<> struct is_arithmetic<Packet4bf> { enum { value = true }; }; + +template<> struct packet_traits<bfloat16> : default_packet_traits +{ + typedef Packet4bf type; + typedef Packet4bf half; + enum + { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 4, + HasHalfPacket = 0, + + HasCmp = 1, + HasAdd = 1, + HasSub = 1, + HasShift = 1, + HasMul = 1, + HasNegate = 1, + HasAbs = 1, + HasArg = 0, + HasAbs2 = 1, + HasAbsDiff = 1, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, + HasBlend = 0, + HasDiv = 1, + HasFloor = 1, + + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, + HasLog = 1, + HasExp = 1, + HasSqrt = 0, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH + }; +}; + +template<> struct unpacket_traits<Packet4bf> +{ + typedef bfloat16 type; + typedef Packet4bf half; + enum + { + size = 4, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; + +EIGEN_STRONG_INLINE Packet4bf F32ToBf16(const Packet4f& p) +{ + // See the scalar implemention in BFloat16.h for a comprehensible explanation + // of this fast rounding algorithm + Packet4ui input = reinterpret_cast<Packet4ui>(p); + + // lsb = (input >> 16) & 1 + Packet4ui lsb = vandq_u32(vshrq_n_u32(input, 16), vdupq_n_u32(1)); + + // rounding_bias = 0x7fff + lsb + Packet4ui rounding_bias = vaddq_u32(lsb, vdupq_n_u32(0x7fff)); + + // input += rounding_bias + input = vaddq_u32(input, rounding_bias); + + // input = input >> 16 + input = vshrq_n_u32(input, 16); + + // Replace float-nans by bfloat16-nans, that is 0x7fc0 + const Packet4ui bf16_nan = vdupq_n_u32(0x7fc0); + const Packet4ui mask = vceqq_f32(p, p); + input = vbslq_u32(mask, input, bf16_nan); + + // output = static_cast<uint16_t>(input) + return vmovn_u32(input); +} + +EIGEN_STRONG_INLINE Packet4f Bf16ToF32(const Packet4bf& p) +{ + return reinterpret_cast<Packet4f>(vshlq_n_u32(vmovl_u16(p), 16)); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pset1<Packet4bf>(const bfloat16& from) { + return pset1<Packet4us>(from.value); +} + +template<> EIGEN_STRONG_INLINE bfloat16 pfirst<Packet4bf>(const Packet4bf& from) { + return bfloat16_impl::raw_uint16_to_bfloat16(static_cast<uint16_t>(pfirst<Packet4us>(from))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pload<Packet4bf>(const bfloat16* from) +{ + return pload<Packet4us>(reinterpret_cast<const uint16_t*>(from)); +} + +template<> EIGEN_STRONG_INLINE Packet4bf ploadu<Packet4bf>(const bfloat16* from) +{ + return ploadu<Packet4us>(reinterpret_cast<const uint16_t*>(from)); +} + +template<> EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet4bf& from) +{ + EIGEN_DEBUG_ALIGNED_STORE vst1_u16(reinterpret_cast<uint16_t*>(to), from); +} + +template<> EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet4bf& from) +{ + EIGEN_DEBUG_UNALIGNED_STORE vst1_u16(reinterpret_cast<uint16_t*>(to), from); +} + +template<> EIGEN_STRONG_INLINE Packet4bf ploaddup<Packet4bf>(const bfloat16* from) +{ + return ploaddup<Packet4us>(reinterpret_cast<const uint16_t*>(from)); +} + +template <> EIGEN_STRONG_INLINE Packet4bf pabs(const Packet4bf& a) { + return F32ToBf16(pabs<Packet4f>(Bf16ToF32(a))); +} + +template <> EIGEN_STRONG_INLINE Packet4bf pmin<Packet4bf>(const Packet4bf &a, + const Packet4bf &b) +{ + return F32ToBf16(pmin<Packet4f>(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> EIGEN_STRONG_INLINE Packet4bf pmax<Packet4bf>(const Packet4bf &a, + const Packet4bf &b) +{ + return F32ToBf16(pmax<Packet4f>(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf por(const Packet4bf& a,const Packet4bf& b) { + return por<Packet4us>(a, b); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pxor(const Packet4bf& a,const Packet4bf& b) { + return pxor<Packet4us>(a, b); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pand(const Packet4bf& a,const Packet4bf& b) { + return pand<Packet4us>(a, b); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pandnot(const Packet4bf& a,const Packet4bf& b) { + return pandnot<Packet4us>(a, b); +} + +template<> EIGEN_DEVICE_FUNC inline Packet4bf pselect(const Packet4bf& mask, const Packet4bf& a, + const Packet4bf& b) +{ + return pselect<Packet4us>(mask, a, b); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pfloor<Packet4bf>(const Packet4bf& a) +{ + return F32ToBf16(pfloor<Packet4f>(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pconj(const Packet4bf& a) { return a; } + +template<> EIGEN_STRONG_INLINE Packet4bf padd<Packet4bf>(const Packet4bf& a, const Packet4bf& b) { + return F32ToBf16(padd<Packet4f>(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf psub<Packet4bf>(const Packet4bf& a, const Packet4bf& b) { + return F32ToBf16(psub<Packet4f>(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pmul<Packet4bf>(const Packet4bf& a, const Packet4bf& b) { + return F32ToBf16(pmul<Packet4f>(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pdiv<Packet4bf>(const Packet4bf& a, const Packet4bf& b) { + return F32ToBf16(pdiv<Packet4f>(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> +EIGEN_STRONG_INLINE Packet4bf pgather<bfloat16, Packet4bf>(const bfloat16* from, Index stride) +{ + return pgather<uint16_t, Packet4us>(reinterpret_cast<const uint16_t*>(from), stride); +} + +template<> +EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet4bf>(bfloat16* to, const Packet4bf& from, Index stride) +{ + pscatter<uint16_t, Packet4us>(reinterpret_cast<uint16_t*>(to), from, stride); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux<Packet4bf>(const Packet4bf& a) +{ + return static_cast<bfloat16>(predux<Packet4f>(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux_max<Packet4bf>(const Packet4bf& a) +{ + return static_cast<bfloat16>(predux_max<Packet4f>(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux_min<Packet4bf>(const Packet4bf& a) +{ + return static_cast<bfloat16>(predux_min<Packet4f>(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet4bf>(const Packet4bf& a) +{ + return static_cast<bfloat16>(predux_mul<Packet4f>(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf preverse<Packet4bf>(const Packet4bf& a) +{ + return preverse<Packet4us>(a); +} + +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet4bf, 4>& kernel) +{ + PacketBlock<Packet4us, 4> k; + k.packet[0] = kernel.packet[0]; + k.packet[1] = kernel.packet[1]; + k.packet[2] = kernel.packet[2]; + k.packet[3] = kernel.packet[3]; + ptranspose(k); + kernel.packet[0] = k.packet[0]; + kernel.packet[1] = k.packet[1]; + kernel.packet[2] = k.packet[2]; + kernel.packet[3] = k.packet[3]; +} + +template<> EIGEN_STRONG_INLINE Packet4bf pabsdiff<Packet4bf>(const Packet4bf& a, const Packet4bf& b) +{ + return F32ToBf16(pabsdiff<Packet4f>(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pcmp_eq<Packet4bf>(const Packet4bf& a, const Packet4bf& b) +{ + return F32ToBf16(pcmp_eq<Packet4f>(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pcmp_lt<Packet4bf>(const Packet4bf& a, const Packet4bf& b) +{ + return F32ToBf16(pcmp_lt<Packet4f>(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pcmp_le<Packet4bf>(const Packet4bf& a, const Packet4bf& b) +{ + return F32ToBf16(pcmp_le<Packet4f>(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pnegate<Packet4bf>(const Packet4bf& a) +{ + return pxor<Packet4us>(a, pset1<Packet4us>(static_cast<uint16_t>(0x8000))); +} + //---------- double ---------- // Clang 3.5 in the iOS toolchain has an ICE triggered by NEON intrisics for double. |