From 8ba1b0f41a7950dc3e1d4ed75859e36c73311235 Mon Sep 17 00:00:00 2001 From: David Tellenbach Date: Thu, 13 Aug 2020 15:48:40 +0000 Subject: bfloat16 packetmath for Arm Neon backend --- Eigen/src/Core/arch/NEON/MathFunctions.h | 6 + Eigen/src/Core/arch/NEON/PacketMath.h | 263 +++++++++++++++++++++++++++++++ 2 files changed, 269 insertions(+) (limited to 'Eigen/src/Core') diff --git a/Eigen/src/Core/arch/NEON/MathFunctions.h b/Eigen/src/Core/arch/NEON/MathFunctions.h index 1c025618e..ab549a683 100644 --- a/Eigen/src/Core/arch/NEON/MathFunctions.h +++ b/Eigen/src/Core/arch/NEON/MathFunctions.h @@ -38,6 +38,12 @@ template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Pack template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f ptanh(const Packet4f& x) { return internal::generic_fast_tanh_float(x); } +BF16_PACKET_FUNCTION(Packet4f, Packet4bf, psin) +BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pcos) +BF16_PACKET_FUNCTION(Packet4f, Packet4bf, plog) +BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pexp) +BF16_PACKET_FUNCTION(Packet4f, Packet4bf, ptanh) + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h index c2fdcbade..26d1d1298 100644 --- a/Eigen/src/Core/arch/NEON/PacketMath.h +++ b/Eigen/src/Core/arch/NEON/PacketMath.h @@ -3218,6 +3218,269 @@ template<> EIGEN_STRONG_INLINE Packet4ui psqrt(const Packet4ui& a) { return res; } +//---------- bfloat16 ---------- +// TODO: Add support for native armv8.6-a bfloat16_t + +// TODO: Guard if we have native bfloat16 support +typedef eigen_packet_wrapper Packet4bf; + +template<> struct is_arithmetic { enum { value = true }; }; + +template<> struct packet_traits : default_packet_traits +{ + typedef Packet4bf type; + typedef Packet4bf half; + enum + { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 4, + HasHalfPacket = 0, + + HasCmp = 1, + HasAdd = 1, + HasSub = 1, + HasShift = 1, + HasMul = 1, + HasNegate = 1, + HasAbs = 1, + HasArg = 0, + HasAbs2 = 1, + HasAbsDiff = 1, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, + HasBlend = 0, + HasDiv = 1, + HasFloor = 1, + + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, + HasLog = 1, + HasExp = 1, + HasSqrt = 0, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH + }; +}; + +template<> struct unpacket_traits +{ + typedef bfloat16 type; + typedef Packet4bf half; + enum + { + size = 4, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; + +EIGEN_STRONG_INLINE Packet4bf F32ToBf16(const Packet4f& p) +{ + // See the scalar implemention in BFloat16.h for a comprehensible explanation + // of this fast rounding algorithm + Packet4ui input = reinterpret_cast(p); + + // lsb = (input >> 16) & 1 + Packet4ui lsb = vandq_u32(vshrq_n_u32(input, 16), vdupq_n_u32(1)); + + // rounding_bias = 0x7fff + lsb + Packet4ui rounding_bias = vaddq_u32(lsb, vdupq_n_u32(0x7fff)); + + // input += rounding_bias + input = vaddq_u32(input, rounding_bias); + + // input = input >> 16 + input = vshrq_n_u32(input, 16); + + // Replace float-nans by bfloat16-nans, that is 0x7fc0 + const Packet4ui bf16_nan = vdupq_n_u32(0x7fc0); + const Packet4ui mask = vceqq_f32(p, p); + input = vbslq_u32(mask, input, bf16_nan); + + // output = static_cast(input) + return vmovn_u32(input); +} + +EIGEN_STRONG_INLINE Packet4f Bf16ToF32(const Packet4bf& p) +{ + return reinterpret_cast(vshlq_n_u32(vmovl_u16(p), 16)); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pset1(const bfloat16& from) { + return pset1(from.value); +} + +template<> EIGEN_STRONG_INLINE bfloat16 pfirst(const Packet4bf& from) { + return bfloat16_impl::raw_uint16_to_bfloat16(static_cast(pfirst(from))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pload(const bfloat16* from) +{ + return pload(reinterpret_cast(from)); +} + +template<> EIGEN_STRONG_INLINE Packet4bf ploadu(const bfloat16* from) +{ + return ploadu(reinterpret_cast(from)); +} + +template<> EIGEN_STRONG_INLINE void pstore(bfloat16* to, const Packet4bf& from) +{ + EIGEN_DEBUG_ALIGNED_STORE vst1_u16(reinterpret_cast(to), from); +} + +template<> EIGEN_STRONG_INLINE void pstoreu(bfloat16* to, const Packet4bf& from) +{ + EIGEN_DEBUG_UNALIGNED_STORE vst1_u16(reinterpret_cast(to), from); +} + +template<> EIGEN_STRONG_INLINE Packet4bf ploaddup(const bfloat16* from) +{ + return ploaddup(reinterpret_cast(from)); +} + +template <> EIGEN_STRONG_INLINE Packet4bf pabs(const Packet4bf& a) { + return F32ToBf16(pabs(Bf16ToF32(a))); +} + +template <> EIGEN_STRONG_INLINE Packet4bf pmin(const Packet4bf &a, + const Packet4bf &b) +{ + return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> EIGEN_STRONG_INLINE Packet4bf pmax(const Packet4bf &a, + const Packet4bf &b) +{ + return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf por(const Packet4bf& a,const Packet4bf& b) { + return por(a, b); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pxor(const Packet4bf& a,const Packet4bf& b) { + return pxor(a, b); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pand(const Packet4bf& a,const Packet4bf& b) { + return pand(a, b); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pandnot(const Packet4bf& a,const Packet4bf& b) { + return pandnot(a, b); +} + +template<> EIGEN_DEVICE_FUNC inline Packet4bf pselect(const Packet4bf& mask, const Packet4bf& a, + const Packet4bf& b) +{ + return pselect(mask, a, b); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pfloor(const Packet4bf& a) +{ + return F32ToBf16(pfloor(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pconj(const Packet4bf& a) { return a; } + +template<> EIGEN_STRONG_INLINE Packet4bf padd(const Packet4bf& a, const Packet4bf& b) { + return F32ToBf16(padd(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf psub(const Packet4bf& a, const Packet4bf& b) { + return F32ToBf16(psub(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pmul(const Packet4bf& a, const Packet4bf& b) { + return F32ToBf16(pmul(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pdiv(const Packet4bf& a, const Packet4bf& b) { + return F32ToBf16(pdiv(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> +EIGEN_STRONG_INLINE Packet4bf pgather(const bfloat16* from, Index stride) +{ + return pgather(reinterpret_cast(from), stride); +} + +template<> +EIGEN_STRONG_INLINE void pscatter(bfloat16* to, const Packet4bf& from, Index stride) +{ + pscatter(reinterpret_cast(to), from, stride); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux(const Packet4bf& a) +{ + return static_cast(predux(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet4bf& a) +{ + return static_cast(predux_max(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux_min(const Packet4bf& a) +{ + return static_cast(predux_min(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux_mul(const Packet4bf& a) +{ + return static_cast(predux_mul(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf preverse(const Packet4bf& a) +{ + return preverse(a); +} + +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) +{ + PacketBlock k; + k.packet[0] = kernel.packet[0]; + k.packet[1] = kernel.packet[1]; + k.packet[2] = kernel.packet[2]; + k.packet[3] = kernel.packet[3]; + ptranspose(k); + kernel.packet[0] = k.packet[0]; + kernel.packet[1] = k.packet[1]; + kernel.packet[2] = k.packet[2]; + kernel.packet[3] = k.packet[3]; +} + +template<> EIGEN_STRONG_INLINE Packet4bf pabsdiff(const Packet4bf& a, const Packet4bf& b) +{ + return F32ToBf16(pabsdiff(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pcmp_eq(const Packet4bf& a, const Packet4bf& b) +{ + return F32ToBf16(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pcmp_lt(const Packet4bf& a, const Packet4bf& b) +{ + return F32ToBf16(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pcmp_le(const Packet4bf& a, const Packet4bf& b) +{ + return F32ToBf16(pcmp_le(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pnegate(const Packet4bf& a) +{ + return pxor(a, pset1(static_cast(0x8000))); +} + //---------- double ---------- // Clang 3.5 in the iOS toolchain has an ICE triggered by NEON intrisics for double. -- cgit v1.2.3