aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/NEON/PacketMath.h
diff options
context:
space:
mode:
authorGravatar David Tellenbach <david.tellenbach@me.com>2020-08-13 15:48:40 +0000
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-08-13 15:48:40 +0000
commit8ba1b0f41a7950dc3e1d4ed75859e36c73311235 (patch)
tree662ac7a45607d0527c925f22c78f0739aa9823d9 /Eigen/src/Core/arch/NEON/PacketMath.h
parent704798d1df4866be335ca013da19a44791f85a7e (diff)
bfloat16 packetmath for Arm Neon backend
Diffstat (limited to 'Eigen/src/Core/arch/NEON/PacketMath.h')
-rw-r--r--Eigen/src/Core/arch/NEON/PacketMath.h263
1 files changed, 263 insertions, 0 deletions
diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h
index c2fdcbade..26d1d1298 100644
--- a/Eigen/src/Core/arch/NEON/PacketMath.h
+++ b/Eigen/src/Core/arch/NEON/PacketMath.h
@@ -3218,6 +3218,269 @@ template<> EIGEN_STRONG_INLINE Packet4ui psqrt(const Packet4ui& a) {
return res;
}
+//---------- bfloat16 ----------
+// TODO: Add support for native armv8.6-a bfloat16_t
+
+// TODO: Guard if we have native bfloat16 support
+// Packet of four bfloat16 values kept as raw 16-bit lanes in a uint16x4_t.
+// The vector type is wrapped in eigen_packet_wrapper with id 19 rather than
+// used directly -- presumably so that Packet4bf is a type distinct from other
+// uint16x4_t based packets and can have its own specializations
+// (NOTE(review): confirm that id 19 is not used by another wrapper).
+typedef eigen_packet_wrapper<uint16x4_t, 19> Packet4bf;
+
+// Flags Packet4bf as an arithmetic type for the is_arithmetic trait.
+template<> struct is_arithmetic<Packet4bf> { enum { value = true }; };
+
+// Packet traits for the scalar type bfloat16 on NEON: names the packet type
+// used for it (Packet4bf, four lanes) and lists, as Has* flags, which packet
+// operations are advertised for it.
+template<> struct packet_traits<bfloat16> : default_packet_traits
+{
+ typedef Packet4bf type;
+ // No smaller packet exists: the "half" packet is the full packet itself
+ // (see HasHalfPacket = 0 below).
+ typedef Packet4bf half;
+ enum
+ {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 4, // number of bfloat16 lanes in a Packet4bf
+ HasHalfPacket = 0,
+
+ HasCmp = 1,
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasAbsDiff = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0,
+ HasDiv = 1,
+ HasFloor = 1,
+
+ // The flags below are enabled only when EIGEN_FAST_MATH is non-zero.
+ // NOTE(review): the matching psin/pcos/plog/pexp/ptanh/perf overloads for
+ // Packet4bf are not defined in this part of the file -- confirm they exist.
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasLog = 1,
+ HasExp = 1,
+ HasSqrt = 0,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH
+ };
+};
+
+// Reverse mapping for Packet4bf: scalar type, half packet, lane count and
+// load/store properties.
+template<> struct unpacket_traits<Packet4bf>
+{
+ typedef bfloat16 type;
+ typedef Packet4bf half; // same packet: there is no smaller bfloat16 packet here
+ enum
+ {
+ size = 4,
+ // NOTE(review): four 16-bit lanes occupy 8 bytes, yet the alignment is
+ // declared as Aligned16 -- confirm this matches the other 64-bit packets.
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+
+// Converts four floats to four bfloat16 values: each 32-bit lane is rounded to
+// nearest-even onto its upper 16 bits, and NaN lanes are replaced by the
+// bfloat16 NaN 0x7fc0.
+//
+// The float lanes are reinterpreted as integers with the NEON intrinsic
+// vreinterpretq_u32_f32. A reinterpret_cast between two non-pointer vector (or
+// wrapper class) types is not valid standard C++ and only compiles where the
+// compiler offers it as an extension.
+EIGEN_STRONG_INLINE Packet4bf F32ToBf16(const Packet4f& p)
+{
+  // See the scalar implementation in BFloat16.h for a comprehensible explanation
+  // of this fast rounding algorithm
+  Packet4ui input = vreinterpretq_u32_f32(p);
+
+  // lsb = (input >> 16) & 1
+  const Packet4ui lsb = vandq_u32(vshrq_n_u32(input, 16), vdupq_n_u32(1));
+
+  // rounding_bias = 0x7fff + lsb
+  const Packet4ui rounding_bias = vaddq_u32(lsb, vdupq_n_u32(0x7fff));
+
+  // input += rounding_bias
+  input = vaddq_u32(input, rounding_bias);
+
+  // input = input >> 16
+  input = vshrq_n_u32(input, 16);
+
+  // Replace float-nans by bfloat16-nans, that is 0x7fc0.
+  // (p == p) is false exactly in the NaN lanes, so the select keeps the
+  // rounded value everywhere else.
+  const Packet4ui bf16_nan = vdupq_n_u32(0x7fc0);
+  const Packet4ui mask = vceqq_f32(p, p);
+  input = vbslq_u32(mask, input, bf16_nan);
+
+  // output = static_cast<uint16_t>(input)
+  return vmovn_u32(input);
+}
+
+// Widens four bfloat16 values to four floats: every 16-bit lane is
+// zero-extended to 32 bits and shifted into the upper half of the float's bit
+// pattern.
+//
+// The integer lanes are reinterpreted as floats with the NEON intrinsic
+// vreinterpretq_f32_u32. A reinterpret_cast between two non-pointer vector (or
+// wrapper class) types is not valid standard C++ and only compiles where the
+// compiler offers it as an extension.
+EIGEN_STRONG_INLINE Packet4f Bf16ToF32(const Packet4bf& p)
+{
+  return vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(p), 16));
+}
+
+// Fills every lane with `from` by broadcasting its raw 16-bit storage
+// (`from.value`) through the Packet4us overload.
+template<> EIGEN_STRONG_INLINE Packet4bf pset1<Packet4bf>(const bfloat16& from)
+{
+  const Packet4us raw_lanes = pset1<Packet4us>(from.value);
+  return raw_lanes;
+}
+
+// Returns the first lane: its raw bits are read via the Packet4us overload
+// and turned back into a bfloat16 without any numeric conversion.
+template<> EIGEN_STRONG_INLINE bfloat16 pfirst<Packet4bf>(const Packet4bf& from)
+{
+  const uint16_t lane0_bits = static_cast<uint16_t>(pfirst<Packet4us>(from));
+  return bfloat16_impl::raw_uint16_to_bfloat16(lane0_bits);
+}
+
+// Aligned load: the bfloat16 array is viewed as raw uint16_t storage and read
+// with the Packet4us overload.
+template<> EIGEN_STRONG_INLINE Packet4bf pload<Packet4bf>(const bfloat16* from)
+{
+  const uint16_t* const raw_src = reinterpret_cast<const uint16_t*>(from);
+  return pload<Packet4us>(raw_src);
+}
+
+// Unaligned load: the bfloat16 array is viewed as raw uint16_t storage and
+// read with the Packet4us overload.
+template<> EIGEN_STRONG_INLINE Packet4bf ploadu<Packet4bf>(const bfloat16* from)
+{
+  const uint16_t* const raw_src = reinterpret_cast<const uint16_t*>(from);
+  return ploadu<Packet4us>(raw_src);
+}
+
+// Aligned-store variant: writes the four raw 16-bit lanes of `from` to `to`
+// with vst1_u16.
+template<> EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet4bf& from)
+{
+  uint16_t* const raw_dst = reinterpret_cast<uint16_t*>(to);
+  EIGEN_DEBUG_ALIGNED_STORE vst1_u16(raw_dst, from);
+}
+
+// Unaligned-store variant: writes the four raw 16-bit lanes of `from` to `to`
+// with vst1_u16.
+template<> EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet4bf& from)
+{
+  uint16_t* const raw_dst = reinterpret_cast<uint16_t*>(to);
+  EIGEN_DEBUG_UNALIGNED_STORE vst1_u16(raw_dst, from);
+}
+
+// Duplicating load: forwarded to the Packet4us overload on the raw uint16_t
+// view of the bfloat16 array, since only the bits are moved.
+template<> EIGEN_STRONG_INLINE Packet4bf ploaddup<Packet4bf>(const bfloat16* from)
+{
+  const uint16_t* const raw_src = reinterpret_cast<const uint16_t*>(from);
+  return ploaddup<Packet4us>(raw_src);
+}
+
+// Lane-wise absolute value, computed in float and converted back to bfloat16.
+template <> EIGEN_STRONG_INLINE Packet4bf pabs(const Packet4bf& a)
+{
+  const Packet4f widened = Bf16ToF32(a);
+  return F32ToBf16(pabs<Packet4f>(widened));
+}
+
+// Lane-wise minimum, computed in float and converted back to bfloat16.
+template <> EIGEN_STRONG_INLINE Packet4bf pmin<Packet4bf>(const Packet4bf &a,
+                                                          const Packet4bf &b)
+{
+  const Packet4f lhs = Bf16ToF32(a);
+  const Packet4f rhs = Bf16ToF32(b);
+  return F32ToBf16(pmin<Packet4f>(lhs, rhs));
+}
+
+// Lane-wise maximum, computed in float and converted back to bfloat16.
+template <> EIGEN_STRONG_INLINE Packet4bf pmax<Packet4bf>(const Packet4bf &a,
+                                                          const Packet4bf &b)
+{
+  const Packet4f lhs = Bf16ToF32(a);
+  const Packet4f rhs = Bf16ToF32(b);
+  return F32ToBf16(pmax<Packet4f>(lhs, rhs));
+}
+
+// Bitwise OR of the raw 16-bit lanes, delegated to the Packet4us overload.
+template<> EIGEN_STRONG_INLINE Packet4bf por(const Packet4bf& a,const Packet4bf& b)
+{
+  const Packet4us combined = por<Packet4us>(a, b);
+  return combined;
+}
+
+// Bitwise XOR of the raw 16-bit lanes, delegated to the Packet4us overload.
+template<> EIGEN_STRONG_INLINE Packet4bf pxor(const Packet4bf& a,const Packet4bf& b)
+{
+  const Packet4us combined = pxor<Packet4us>(a, b);
+  return combined;
+}
+
+// Bitwise AND of the raw 16-bit lanes, delegated to the Packet4us overload.
+template<> EIGEN_STRONG_INLINE Packet4bf pand(const Packet4bf& a,const Packet4bf& b)
+{
+  const Packet4us combined = pand<Packet4us>(a, b);
+  return combined;
+}
+
+// Bitwise and-not of the raw 16-bit lanes, delegated to the Packet4us overload.
+template<> EIGEN_STRONG_INLINE Packet4bf pandnot(const Packet4bf& a,const Packet4bf& b)
+{
+  const Packet4us combined = pandnot<Packet4us>(a, b);
+  return combined;
+}
+
+// Mask-driven selection between `a` and `b`; it works on the raw 16-bit lanes
+// and is therefore delegated to the Packet4us overload.
+template<> EIGEN_DEVICE_FUNC inline Packet4bf pselect(const Packet4bf& mask,
+                                                      const Packet4bf& a,
+                                                      const Packet4bf& b)
+{
+  const Packet4us chosen = pselect<Packet4us>(mask, a, b);
+  return chosen;
+}
+
+// Lane-wise floor, computed in float and converted back to bfloat16.
+template<> EIGEN_STRONG_INLINE Packet4bf pfloor<Packet4bf>(const Packet4bf& a)
+{
+  const Packet4f widened = Bf16ToF32(a);
+  return F32ToBf16(pfloor<Packet4f>(widened));
+}
+
+// Conjugation returns the packet unchanged.
+template<> EIGEN_STRONG_INLINE Packet4bf pconj(const Packet4bf& a)
+{
+  return a;
+}
+
+// Lane-wise sum, computed in float and converted back to bfloat16.
+template<> EIGEN_STRONG_INLINE Packet4bf padd<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
+{
+  const Packet4f lhs = Bf16ToF32(a);
+  const Packet4f rhs = Bf16ToF32(b);
+  return F32ToBf16(padd<Packet4f>(lhs, rhs));
+}
+
+// Lane-wise difference a - b, computed in float and converted back to bfloat16.
+template<> EIGEN_STRONG_INLINE Packet4bf psub<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
+{
+  const Packet4f lhs = Bf16ToF32(a);
+  const Packet4f rhs = Bf16ToF32(b);
+  return F32ToBf16(psub<Packet4f>(lhs, rhs));
+}
+
+// Lane-wise product, computed in float and converted back to bfloat16.
+template<> EIGEN_STRONG_INLINE Packet4bf pmul<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
+{
+  const Packet4f lhs = Bf16ToF32(a);
+  const Packet4f rhs = Bf16ToF32(b);
+  return F32ToBf16(pmul<Packet4f>(lhs, rhs));
+}
+
+// Lane-wise quotient a / b, computed in float and converted back to bfloat16.
+template<> EIGEN_STRONG_INLINE Packet4bf pdiv<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
+{
+  const Packet4f numerator = Bf16ToF32(a);
+  const Packet4f denominator = Bf16ToF32(b);
+  return F32ToBf16(pdiv<Packet4f>(numerator, denominator));
+}
+
+// Strided gather: forwarded to the uint16_t/Packet4us overload on the raw
+// 16-bit view of the bfloat16 array, with the same stride.
+template<>
+EIGEN_STRONG_INLINE Packet4bf pgather<bfloat16, Packet4bf>(const bfloat16* from, Index stride)
+{
+  const uint16_t* const raw_src = reinterpret_cast<const uint16_t*>(from);
+  return pgather<uint16_t, Packet4us>(raw_src, stride);
+}
+
+// Strided scatter: forwarded to the uint16_t/Packet4us overload on the raw
+// 16-bit view of the bfloat16 array, with the same stride.
+template<>
+EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet4bf>(bfloat16* to, const Packet4bf& from, Index stride)
+{
+  uint16_t* const raw_dst = reinterpret_cast<uint16_t*>(to);
+  pscatter<uint16_t, Packet4us>(raw_dst, from, stride);
+}
+
+// Reduction via predux<Packet4f>: the packet is widened to float, reduced
+// there, and the scalar result is cast to bfloat16.
+template<> EIGEN_STRONG_INLINE bfloat16 predux<Packet4bf>(const Packet4bf& a)
+{
+  const Packet4f widened = Bf16ToF32(a);
+  return static_cast<bfloat16>(predux<Packet4f>(widened));
+}
+
+// Reduction via predux_max<Packet4f>: the packet is widened to float, reduced
+// there, and the scalar result is cast to bfloat16.
+template<> EIGEN_STRONG_INLINE bfloat16 predux_max<Packet4bf>(const Packet4bf& a)
+{
+  const Packet4f widened = Bf16ToF32(a);
+  return static_cast<bfloat16>(predux_max<Packet4f>(widened));
+}
+
+// Reduction via predux_min<Packet4f>: the packet is widened to float, reduced
+// there, and the scalar result is cast to bfloat16.
+template<> EIGEN_STRONG_INLINE bfloat16 predux_min<Packet4bf>(const Packet4bf& a)
+{
+  const Packet4f widened = Bf16ToF32(a);
+  return static_cast<bfloat16>(predux_min<Packet4f>(widened));
+}
+
+// Reduction via predux_mul<Packet4f>: the packet is widened to float, reduced
+// there, and the scalar result is cast to bfloat16.
+template<> EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet4bf>(const Packet4bf& a)
+{
+  const Packet4f widened = Bf16ToF32(a);
+  return static_cast<bfloat16>(predux_mul<Packet4f>(widened));
+}
+
+// Lane reversal only moves bits around, so it is delegated to the Packet4us
+// overload.
+template<> EIGEN_STRONG_INLINE Packet4bf preverse<Packet4bf>(const Packet4bf& a)
+{
+  const Packet4us reversed = preverse<Packet4us>(a);
+  return reversed;
+}
+
+// Transposes a 4x4 block of bfloat16 in place. The four packets are copied
+// into a Packet4us block, transposed with the Packet4us overload, and copied
+// back into `kernel`.
+EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet4bf, 4>& kernel)
+{
+  PacketBlock<Packet4us, 4> raw_block;
+  for (int row = 0; row < 4; ++row) {
+    raw_block.packet[row] = kernel.packet[row];
+  }
+
+  ptranspose(raw_block);
+
+  for (int row = 0; row < 4; ++row) {
+    kernel.packet[row] = raw_block.packet[row];
+  }
+}
+
+// Lane-wise absolute difference, computed in float and converted back to
+// bfloat16.
+template<> EIGEN_STRONG_INLINE Packet4bf pabsdiff<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
+{
+  const Packet4f lhs = Bf16ToF32(a);
+  const Packet4f rhs = Bf16ToF32(b);
+  return F32ToBf16(pabsdiff<Packet4f>(lhs, rhs));
+}
+
+// Lane-wise a == b for bfloat16 packets, returned as a per-lane bit mask.
+//
+// The comparison is done in float. Its result has to be narrowed as raw bits
+// and NOT through F32ToBf16(): an all-ones 32-bit lane (the expected "true"
+// value of the float mask -- NOTE(review): confirm against pcmp_eq<Packet4f>)
+// has the bit pattern of a NaN, and F32ToBf16() rewrites NaN lanes to 0x7fc0.
+// That would hand a partially-set mask to bitwise consumers such as pselect.
+// vmovn_u32 keeps the low 16 bits of each 32-bit lane, so 0xffffffff becomes
+// 0xffff and 0 stays 0.
+template<> EIGEN_STRONG_INLINE Packet4bf pcmp_eq<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
+{
+  const Packet4f mask32 = pcmp_eq<Packet4f>(Bf16ToF32(a), Bf16ToF32(b));
+  return vmovn_u32(vreinterpretq_u32_f32(mask32));
+}
+
+// Lane-wise a < b for bfloat16 packets, returned as a per-lane bit mask.
+//
+// The comparison is done in float. Its result has to be narrowed as raw bits
+// and NOT through F32ToBf16(): an all-ones 32-bit lane (the expected "true"
+// value of the float mask -- NOTE(review): confirm against pcmp_lt<Packet4f>)
+// has the bit pattern of a NaN, and F32ToBf16() rewrites NaN lanes to 0x7fc0.
+// That would hand a partially-set mask to bitwise consumers such as pselect.
+// vmovn_u32 keeps the low 16 bits of each 32-bit lane, so 0xffffffff becomes
+// 0xffff and 0 stays 0.
+template<> EIGEN_STRONG_INLINE Packet4bf pcmp_lt<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
+{
+  const Packet4f mask32 = pcmp_lt<Packet4f>(Bf16ToF32(a), Bf16ToF32(b));
+  return vmovn_u32(vreinterpretq_u32_f32(mask32));
+}
+
+// Lane-wise a <= b for bfloat16 packets, returned as a per-lane bit mask.
+//
+// The comparison is done in float. Its result has to be narrowed as raw bits
+// and NOT through F32ToBf16(): an all-ones 32-bit lane (the expected "true"
+// value of the float mask -- NOTE(review): confirm against pcmp_le<Packet4f>)
+// has the bit pattern of a NaN, and F32ToBf16() rewrites NaN lanes to 0x7fc0.
+// That would hand a partially-set mask to bitwise consumers such as pselect.
+// vmovn_u32 keeps the low 16 bits of each 32-bit lane, so 0xffffffff becomes
+// 0xffff and 0 stays 0.
+template<> EIGEN_STRONG_INLINE Packet4bf pcmp_le<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
+{
+  const Packet4f mask32 = pcmp_le<Packet4f>(Bf16ToF32(a), Bf16ToF32(b));
+  return vmovn_u32(vreinterpretq_u32_f32(mask32));
+}
+
+// Negates every lane by flipping bit 15 (0x8000, the sign bit of the 16-bit
+// storage); no conversion to float is involved.
+template<> EIGEN_STRONG_INLINE Packet4bf pnegate<Packet4bf>(const Packet4bf& a)
+{
+  const Packet4us sign_bit = pset1<Packet4us>(static_cast<uint16_t>(0x8000));
+  return pxor<Packet4us>(a, sign_bit);
+}
+
//---------- double ----------
// Clang 3.5 in the iOS toolchain has an ICE triggered by NEON intrisics for double.