diff options
author | David Tellenbach <david.tellenbach@me.com> | 2020-10-28 20:15:09 +0000 |
---|---|---|
committer | David Tellenbach <david.tellenbach@me.com> | 2020-10-28 20:15:09 +0000 |
commit | e265f7ed8e59c26e15f2c35162c6b8da1c5d594f (patch) | |
tree | 09f9696465ca75ecfdaeccda88358f397616042d /Eigen/src/Core/arch/NEON/PacketMath.h | |
parent | a725a3233c98185eb3e5db6186aea3a906b8411f (diff) |
Add support for Armv8.2-a __fp16
Armv8.2-a provides a native half-precision floating-point type (__fp16, a.k.a.
float16_t). This patch introduces
* __fp16 as the underlying type of Eigen::half if this type is available
* the packet types Packet4hf and Packet8hf representing float16x4_t and
float16x8_t respectively
* packet-math for the above packets with corresponding scalar type Eigen::half
The packet-math functionality has been implemented by Ashutosh Sharma
<ashutosh.sharma@amperecomputing.com>.
This closes #1940.
Diffstat (limited to 'Eigen/src/Core/arch/NEON/PacketMath.h')
-rw-r--r-- | Eigen/src/Core/arch/NEON/PacketMath.h | 644 |
1 files changed, 644 insertions, 0 deletions
// Half-precision (fp16) NEON packet math, from Eigen/src/Core/arch/NEON/PacketMath.h.
// In that header this section follows the `#endif // EIGEN_ARCH_ARM64` line and
// is immediately followed by the closing braces of namespace `internal` and
// namespace `Eigen`, i.e. everything below lives in Eigen::internal.

// Do we have fp16 types and supporting Neon intrinsics?
#if EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
typedef float16x4_t Packet4hf;
typedef float16x8_t Packet8hf;

// Packet traits for Eigen::half. The 4-lane packet is used both as the full
// and as the "half" packet type, and HasHalfPacket is 0.
// TODO(tellenbach): Enable packets of size 8 as soon as the GEBP can handle them
template <>
struct packet_traits<Eigen::half> : default_packet_traits {
  typedef Packet4hf type;
  typedef Packet4hf half;
  enum {
    Vectorizable = 1,
    AlignedOnScalar = 1,
    size = 4,
    HasHalfPacket = 0,

    HasCmp = 1,
    HasCast = 1,
    HasAdd = 1,
    HasSub = 1,
    HasShift = 1,
    HasMul = 1,
    HasNegate = 1,
    HasAbs = 1,
    HasArg = 0,
    HasAbs2 = 1,
    HasAbsDiff = 0,
    HasMin = 1,
    HasMax = 1,
    HasConj = 1,
    HasSetLinear = 0,
    HasBlend = 0,
    HasInsert = 1,
    HasReduxp = 1,
    HasDiv = 1,
    HasFloor = 1,
    HasSin = 0,
    HasCos = 0,
    HasLog = 0,
    HasExp = 0,
    HasSqrt = 1
  };
};

// Reverse mapping: 4-lane packet -> scalar type and packet properties.
template <>
struct unpacket_traits<Packet4hf> {
  typedef Eigen::half type;
  typedef Packet4hf half;
  enum {
    size = 4,
    alignment = Aligned16,
    vectorizable = true,
    masked_load_available = false,
    masked_store_available = false
  };
};

// Reverse mapping: 8-lane packet -> scalar type and packet properties.
template <>
struct unpacket_traits<Packet8hf> {
  typedef Eigen::half type;
  typedef Packet8hf half;
  enum {
    size = 8,
    alignment = Aligned16,
    vectorizable = true,
    masked_load_available = false,
    masked_store_available = false
  };
};

// ---------------------------------------------------------------------------
// Broadcast / linear fill
// ---------------------------------------------------------------------------

// Broadcasts `from` into all 8 lanes.
template <>
EIGEN_STRONG_INLINE Packet8hf pset1<Packet8hf>(const Eigen::half& from) {
  return vdupq_n_f16(from.x);
}

// Broadcasts `from` into all 4 lanes.
template <>
EIGEN_STRONG_INLINE Packet4hf pset1<Packet4hf>(const Eigen::half& from) {
  return vdup_n_f16(from.x);
}

// Returns {a, a+1, ..., a+7}.
template <>
EIGEN_STRONG_INLINE Packet8hf plset<Packet8hf>(const Eigen::half& a) {
  const float16_t f[] = {0, 1, 2, 3, 4, 5, 6, 7};
  const Packet8hf offsets = vld1q_f16(f);
  return vaddq_f16(pset1<Packet8hf>(a), offsets);
}

// Returns {a, a+1, a+2, a+3}.
template <>
EIGEN_STRONG_INLINE Packet4hf plset<Packet4hf>(const Eigen::half& a) {
  const float16_t f[] = {0, 1, 2, 3};
  const Packet4hf offsets = vld1_f16(f);
  return vadd_f16(pset1<Packet4hf>(a), offsets);
}

// ---------------------------------------------------------------------------
// Lane-wise arithmetic
// ---------------------------------------------------------------------------

// Lane-wise a + b.
template <>
EIGEN_STRONG_INLINE Packet8hf padd<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
  return vaddq_f16(a, b);
}

// Lane-wise a + b.
template <>
EIGEN_STRONG_INLINE Packet4hf padd<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
  return vadd_f16(a, b);
}

// Lane-wise a - b.
template <>
EIGEN_STRONG_INLINE Packet8hf psub<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
  return vsubq_f16(a, b);
}

// Lane-wise a - b.
template <>
EIGEN_STRONG_INLINE Packet4hf psub<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
  return vsub_f16(a, b);
}

// Lane-wise -a.
template <>
EIGEN_STRONG_INLINE Packet8hf pnegate(const Packet8hf& a) {
  return vnegq_f16(a);
}

// Lane-wise -a.
template <>
EIGEN_STRONG_INLINE Packet4hf pnegate(const Packet4hf& a) {
  return vneg_f16(a);
}

// Conjugate of a real-valued packet: the identity.
template <>
EIGEN_STRONG_INLINE Packet8hf pconj(const Packet8hf& a) {
  return a;
}

// Conjugate of a real-valued packet: the identity.
template <>
EIGEN_STRONG_INLINE Packet4hf pconj(const Packet4hf& a) {
  return a;
}

// Lane-wise a * b.
template <>
EIGEN_STRONG_INLINE Packet8hf pmul<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
  return vmulq_f16(a, b);
}

// Lane-wise a * b.
template <>
EIGEN_STRONG_INLINE Packet4hf pmul<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
  return vmul_f16(a, b);
}

// Lane-wise a / b.
template <>
EIGEN_STRONG_INLINE Packet8hf pdiv<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
  return vdivq_f16(a, b);
}

// Lane-wise a / b.
template <>
EIGEN_STRONG_INLINE Packet4hf pdiv<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
  return vdiv_f16(a, b);
}

// Lane-wise fused multiply-add: a * b + c (note the accumulator-first
// argument order of the intrinsic).
template <>
EIGEN_STRONG_INLINE Packet8hf pmadd(const Packet8hf& a, const Packet8hf& b, const Packet8hf& c) {
  return vfmaq_f16(c, a, b);
}

// Lane-wise fused multiply-add: a * b + c (note the accumulator-first
// argument order of the intrinsic).
template <>
EIGEN_STRONG_INLINE Packet4hf pmadd(const Packet4hf& a, const Packet4hf& b, const Packet4hf& c) {
  return vfma_f16(c, a, b);
}

// Lane-wise minimum.
template <>
EIGEN_STRONG_INLINE Packet8hf pmin<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
  return vminq_f16(a, b);
}

// Lane-wise minimum.
template <>
EIGEN_STRONG_INLINE Packet4hf pmin<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
  return vmin_f16(a, b);
}

// Lane-wise maximum.
template <>
EIGEN_STRONG_INLINE Packet8hf pmax<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
  return vmaxq_f16(a, b);
}

// Lane-wise maximum.
template <>
EIGEN_STRONG_INLINE Packet4hf pmax<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
  return vmax_f16(a, b);
}

// ---------------------------------------------------------------------------
// Comparisons
// ---------------------------------------------------------------------------

// Generates pcmp_<name> for 8-lane packets from the matching vc<name>q_f16
// intrinsic. The unsigned 16-bit lane mask produced by the intrinsic is
// returned bit-for-bit, reinterpreted as a half packet.
#define EIGEN_MAKE_ARM_FP16_CMP_8(name)                                               \
  template <>                                                                         \
  EIGEN_STRONG_INLINE Packet8hf pcmp_##name(const Packet8hf& a, const Packet8hf& b) { \
    return vreinterpretq_f16_u16(vc##name##q_f16(a, b));                              \
  }

// Same as above for 4-lane packets (vc<name>_f16).
#define EIGEN_MAKE_ARM_FP16_CMP_4(name)                                               \
  template <>                                                                         \
  EIGEN_STRONG_INLINE Packet4hf pcmp_##name(const Packet4hf& a, const Packet4hf& b) { \
    return vreinterpret_f16_u16(vc##name##_f16(a, b));                                \
  }

EIGEN_MAKE_ARM_FP16_CMP_8(eq)
EIGEN_MAKE_ARM_FP16_CMP_8(lt)
EIGEN_MAKE_ARM_FP16_CMP_8(le)

EIGEN_MAKE_ARM_FP16_CMP_4(eq)
EIGEN_MAKE_ARM_FP16_CMP_4(lt)
EIGEN_MAKE_ARM_FP16_CMP_4(le)

#undef EIGEN_MAKE_ARM_FP16_CMP_8
#undef EIGEN_MAKE_ARM_FP16_CMP_4

// Mask of lanes where NOT (a >= b), i.e. a < b or the comparison is unordered.
template <>
EIGEN_STRONG_INLINE Packet8hf pcmp_lt_or_nan<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
  return vreinterpretq_f16_u16(vmvnq_u16(vcgeq_f16(a, b)));
}

// Mask of lanes where NOT (a >= b), i.e. a < b or the comparison is unordered.
template <>
EIGEN_STRONG_INLINE Packet4hf pcmp_lt_or_nan<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
  return vreinterpret_f16_u16(vmvn_u16(vcge_f16(a, b)));
}

// ---------------------------------------------------------------------------
// Rounding / roots
// ---------------------------------------------------------------------------

// Lane-wise floor. Uses the native round-toward-minus-infinity intrinsic
// rather than a half -> int16 -> half round trip: the integer conversion
// saturates at the int16 range, which is narrower than the finite range of
// half, and cannot carry infinities or NaN.
template <>
EIGEN_STRONG_INLINE Packet8hf pfloor<Packet8hf>(const Packet8hf& a) {
  return vrndmq_f16(a);
}

// Lane-wise floor; see the 8-lane overload for why vrndm is used.
template <>
EIGEN_STRONG_INLINE Packet4hf pfloor<Packet4hf>(const Packet4hf& a) {
  return vrndm_f16(a);
}

// Lane-wise square root.
template <>
EIGEN_STRONG_INLINE Packet8hf psqrt<Packet8hf>(const Packet8hf& a) {
  return vsqrtq_f16(a);
}

// Lane-wise square root.
template <>
EIGEN_STRONG_INLINE Packet4hf psqrt<Packet4hf>(const Packet4hf& a) {
  return vsqrt_f16(a);
}

// ---------------------------------------------------------------------------
// Bitwise logic (performed on the raw 16-bit lane patterns)
// ---------------------------------------------------------------------------

// Bitwise a & b.
template <>
EIGEN_STRONG_INLINE Packet8hf pand<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
  return vreinterpretq_f16_u16(vandq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)));
}

// Bitwise a & b.
template <>
EIGEN_STRONG_INLINE Packet4hf pand<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
  return vreinterpret_f16_u16(vand_u16(vreinterpret_u16_f16(a), vreinterpret_u16_f16(b)));
}

// Bitwise a | b.
template <>
EIGEN_STRONG_INLINE Packet8hf por<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
  return vreinterpretq_f16_u16(vorrq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)));
}

// Bitwise a | b.
template <>
EIGEN_STRONG_INLINE Packet4hf por<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
  return vreinterpret_f16_u16(vorr_u16(vreinterpret_u16_f16(a), vreinterpret_u16_f16(b)));
}

// Bitwise a ^ b.
template <>
EIGEN_STRONG_INLINE Packet8hf pxor<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
  return vreinterpretq_f16_u16(veorq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)));
}

// Bitwise a ^ b.
template <>
EIGEN_STRONG_INLINE Packet4hf pxor<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
  return vreinterpret_f16_u16(veor_u16(vreinterpret_u16_f16(a), vreinterpret_u16_f16(b)));
}

// Bitwise a & ~b (bit clear).
template <>
EIGEN_STRONG_INLINE Packet8hf pandnot<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
  return vreinterpretq_f16_u16(vbicq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)));
}

// Bitwise a & ~b (bit clear).
template <>
EIGEN_STRONG_INLINE Packet4hf pandnot<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
  return vreinterpret_f16_u16(vbic_u16(vreinterpret_u16_f16(a), vreinterpret_u16_f16(b)));
}

// ---------------------------------------------------------------------------
// Loads
// ---------------------------------------------------------------------------

// Aligned load of 8 consecutive halves.
template <>
EIGEN_STRONG_INLINE Packet8hf pload<Packet8hf>(const Eigen::half* from) {
  EIGEN_DEBUG_ALIGNED_LOAD return vld1q_f16(reinterpret_cast<const float16_t*>(from));
}

// Aligned load of 4 consecutive halves.
template <>
EIGEN_STRONG_INLINE Packet4hf pload<Packet4hf>(const Eigen::half* from) {
  EIGEN_DEBUG_ALIGNED_LOAD return vld1_f16(reinterpret_cast<const float16_t*>(from));
}

// Unaligned load of 8 consecutive halves.
template <>
EIGEN_STRONG_INLINE Packet8hf ploadu<Packet8hf>(const Eigen::half* from) {
  EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_f16(reinterpret_cast<const float16_t*>(from));
}

// Unaligned load of 4 consecutive halves.
template <>
EIGEN_STRONG_INLINE Packet4hf ploadu<Packet4hf>(const Eigen::half* from) {
  EIGEN_DEBUG_UNALIGNED_LOAD return vld1_f16(reinterpret_cast<const float16_t*>(from));
}

// Returns {f0, f0, f1, f1, f2, f2, f3, f3}, reading from[0..3] only.
// Each 4-lane half is built from a broadcast plus two lane inserts so that no
// uninitialised vector is read and no compiler-specific vector subscripting
// is needed.
template <>
EIGEN_STRONG_INLINE Packet8hf ploaddup<Packet8hf>(const Eigen::half* from) {
  Packet4hf lo = vdup_n_f16(from[0].x);
  lo = vset_lane_f16(from[1].x, lo, 2);
  lo = vset_lane_f16(from[1].x, lo, 3);

  Packet4hf hi = vdup_n_f16(from[2].x);
  hi = vset_lane_f16(from[3].x, hi, 2);
  hi = vset_lane_f16(from[3].x, hi, 3);

  return vcombine_f16(lo, hi);
}

// Returns {f0, f0, f1, f1}, reading from[0..1] only.
template <>
EIGEN_STRONG_INLINE Packet4hf ploaddup<Packet4hf>(const Eigen::half* from) {
  Packet4hf res = vdup_n_f16(from[0].x);
  res = vset_lane_f16(from[1].x, res, 2);
  res = vset_lane_f16(from[1].x, res, 3);
  return res;
}

// Returns {f0, f0, f0, f0, f1, f1, f1, f1}, reading from[0..1] only.
template <>
EIGEN_STRONG_INLINE Packet8hf ploadquad<Packet8hf>(const Eigen::half* from) {
  const Packet4hf lo = vld1_dup_f16(reinterpret_cast<const float16_t*>(from));
  const Packet4hf hi = vld1_dup_f16(reinterpret_cast<const float16_t*>(from + 1));
  return vcombine_f16(lo, hi);
}

// ---------------------------------------------------------------------------
// Lane insertion / selection
// ---------------------------------------------------------------------------

// Returns `a` with lane 0 replaced by `b`.
EIGEN_DEVICE_FUNC inline Packet8hf pinsertfirst(const Packet8hf& a, Eigen::half b) { return vsetq_lane_f16(b.x, a, 0); }

// Returns `a` with lane 0 replaced by `b`.
EIGEN_DEVICE_FUNC inline Packet4hf pinsertfirst(const Packet4hf& a, Eigen::half b) { return vset_lane_f16(b.x, a, 0); }

// Bitwise select: takes bits from `a` where `mask` is set and from `b` elsewhere.
template <>
EIGEN_DEVICE_FUNC inline Packet8hf pselect(const Packet8hf& mask, const Packet8hf& a, const Packet8hf& b) {
  return vbslq_f16(vreinterpretq_u16_f16(mask), a, b);
}

// Bitwise select: takes bits from `a` where `mask` is set and from `b` elsewhere.
template <>
EIGEN_DEVICE_FUNC inline Packet4hf pselect(const Packet4hf& mask, const Packet4hf& a, const Packet4hf& b) {
  return vbsl_f16(vreinterpret_u16_f16(mask), a, b);
}

// Returns `a` with its last lane (7) replaced by `b`.
EIGEN_DEVICE_FUNC inline Packet8hf pinsertlast(const Packet8hf& a, Eigen::half b) { return vsetq_lane_f16(b.x, a, 7); }

// Returns `a` with its last lane (3) replaced by `b`.
EIGEN_DEVICE_FUNC inline Packet4hf pinsertlast(const Packet4hf& a, Eigen::half b) { return vset_lane_f16(b.x, a, 3); }

// ---------------------------------------------------------------------------
// Stores
// ---------------------------------------------------------------------------

// Aligned store of 8 halves.
template <>
EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet8hf& from) {
  EIGEN_DEBUG_ALIGNED_STORE vst1q_f16(reinterpret_cast<float16_t*>(to), from);
}

// Aligned store of 4 halves.
template <>
EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet4hf& from) {
  EIGEN_DEBUG_ALIGNED_STORE vst1_f16(reinterpret_cast<float16_t*>(to), from);
}

// Unaligned store of 8 halves.
template <>
EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const Packet8hf& from) {
  EIGEN_DEBUG_UNALIGNED_STORE vst1q_f16(reinterpret_cast<float16_t*>(to), from);
}

// Unaligned store of 4 halves.
template <>
EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const Packet4hf& from) {
  EIGEN_DEBUG_UNALIGNED_STORE vst1_f16(reinterpret_cast<float16_t*>(to), from);
}

// ---------------------------------------------------------------------------
// Strided gather / scatter
// ---------------------------------------------------------------------------

// Gathers from[0], from[stride], ..., from[7 * stride] into lanes 0..7.
template <>
EIGEN_DEVICE_FUNC inline Packet8hf pgather<Eigen::half, Packet8hf>(const Eigen::half* from, Index stride) {
  Packet8hf res = pset1<Packet8hf>(Eigen::half(0.f));
  res = vsetq_lane_f16(from[0 * stride].x, res, 0);
  res = vsetq_lane_f16(from[1 * stride].x, res, 1);
  res = vsetq_lane_f16(from[2 * stride].x, res, 2);
  res = vsetq_lane_f16(from[3 * stride].x, res, 3);
  res = vsetq_lane_f16(from[4 * stride].x, res, 4);
  res = vsetq_lane_f16(from[5 * stride].x, res, 5);
  res = vsetq_lane_f16(from[6 * stride].x, res, 6);
  res = vsetq_lane_f16(from[7 * stride].x, res, 7);
  return res;
}

// Gathers from[0], from[stride], ..., from[3 * stride] into lanes 0..3.
template <>
EIGEN_DEVICE_FUNC inline Packet4hf pgather<Eigen::half, Packet4hf>(const Eigen::half* from, Index stride) {
  Packet4hf res = pset1<Packet4hf>(Eigen::half(0.f));
  res = vset_lane_f16(from[0 * stride].x, res, 0);
  res = vset_lane_f16(from[1 * stride].x, res, 1);
  res = vset_lane_f16(from[2 * stride].x, res, 2);
  res = vset_lane_f16(from[3 * stride].x, res, 3);
  return res;
}

// Scatters lanes 0..7 to to[0], to[stride], ..., to[7 * stride].
template <>
EIGEN_DEVICE_FUNC inline void pscatter<Eigen::half, Packet8hf>(Eigen::half* to, const Packet8hf& from, Index stride) {
  to[stride * 0].x = vgetq_lane_f16(from, 0);
  to[stride * 1].x = vgetq_lane_f16(from, 1);
  to[stride * 2].x = vgetq_lane_f16(from, 2);
  to[stride * 3].x = vgetq_lane_f16(from, 3);
  to[stride * 4].x = vgetq_lane_f16(from, 4);
  to[stride * 5].x = vgetq_lane_f16(from, 5);
  to[stride * 6].x = vgetq_lane_f16(from, 6);
  to[stride * 7].x = vgetq_lane_f16(from, 7);
}

// Scatters lanes 0..3 to to[0], to[stride], ..., to[3 * stride].
template <>
EIGEN_DEVICE_FUNC inline void pscatter<Eigen::half, Packet4hf>(Eigen::half* to, const Packet4hf& from, Index stride) {
  to[stride * 0].x = vget_lane_f16(from, 0);
  to[stride * 1].x = vget_lane_f16(from, 1);
  to[stride * 2].x = vget_lane_f16(from, 2);
  to[stride * 3].x = vget_lane_f16(from, 3);
}

// Issues a prefetch hint for `addr`.
template <>
EIGEN_STRONG_INLINE void prefetch<Eigen::half>(const Eigen::half* addr) {
  EIGEN_ARM_PREFETCH(addr);
}

// ---------------------------------------------------------------------------
// Lane extraction / permutation
// ---------------------------------------------------------------------------

// Returns lane 0. Read directly from the register instead of spilling the
// whole packet to a stack buffer.
template <>
EIGEN_STRONG_INLINE Eigen::half pfirst<Packet8hf>(const Packet8hf& a) {
  Eigen::half h;
  h.x = vgetq_lane_f16(a, 0);
  return h;
}

// Returns lane 0.
template <>
EIGEN_STRONG_INLINE Eigen::half pfirst<Packet4hf>(const Packet4hf& a) {
  Eigen::half h;
  h.x = vget_lane_f16(a, 0);
  return h;
}

// Reverses the lane order: vrev64q reverses within each 64-bit half, then the
// two halves are swapped.
template<> EIGEN_STRONG_INLINE Packet8hf preverse(const Packet8hf& a) {
  const Packet8hf a_r64 = vrev64q_f16(a);
  const float16x4_t a_lo = vget_low_f16(a_r64);
  const float16x4_t a_hi = vget_high_f16(a_r64);
  return vcombine_f16(a_hi, a_lo);
}

// Reverses the lane order.
template <>
EIGEN_STRONG_INLINE Packet4hf preverse<Packet4hf>(const Packet4hf& a) {
  return vrev64_f16(a);
}

// Lane-wise absolute value.
template <>
EIGEN_STRONG_INLINE Packet8hf pabs<Packet8hf>(const Packet8hf& a) {
  return vabsq_f16(a);
}

// Lane-wise absolute value.
template <>
EIGEN_STRONG_INLINE Packet4hf pabs<Packet4hf>(const Packet4hf& a) {
  return vabs_f16(a);
}

// ---------------------------------------------------------------------------
// Horizontal reductions
// ---------------------------------------------------------------------------

// Sum of all 8 lanes, via three rounds of pairwise adds.
template <>
EIGEN_STRONG_INLINE Eigen::half predux<Packet8hf>(const Packet8hf& a) {
  const float16x4_t a_lo = vget_low_f16(a);
  const float16x4_t a_hi = vget_high_f16(a);

  float16x4_t sum = vpadd_f16(a_lo, a_hi);
  sum = vpadd_f16(sum, sum);
  sum = vpadd_f16(sum, sum);

  Eigen::half h;
  h.x = vget_lane_f16(sum, 0);
  return h;
}

// Sum of all 4 lanes, via two rounds of pairwise adds.
template <>
EIGEN_STRONG_INLINE Eigen::half predux<Packet4hf>(const Packet4hf& a) {
  float16x4_t sum = vpadd_f16(a, a);
  sum = vpadd_f16(sum, sum);

  Eigen::half h;
  h.x = vget_lane_f16(sum, 0);
  return h;
}

// Product of all 8 lanes.
template <>
EIGEN_STRONG_INLINE Eigen::half predux_mul<Packet8hf>(const Packet8hf& a) {
  const float16x4_t a_lo = vget_low_f16(a);
  const float16x4_t a_hi = vget_high_f16(a);

  // 8 -> 4 lanes, then multiply with the reversed vector so that lanes 0 and
  // 1 together cover all remaining factors.
  float16x4_t prod = vmul_f16(a_lo, a_hi);
  prod = vmul_f16(prod, vrev64_f16(prod));

  Eigen::half h;
  h.x = vget_lane_f16(prod, 0) * vget_lane_f16(prod, 1);
  return h;
}

// Product of all 4 lanes.
template <>
EIGEN_STRONG_INLINE Eigen::half predux_mul<Packet4hf>(const Packet4hf& a) {
  // After this, lane 0 holds a0*a3 and lane 1 holds a1*a2.
  const float16x4_t prod = vmul_f16(a, vrev64_f16(a));

  Eigen::half h;
  h.x = vget_lane_f16(prod, 0) * vget_lane_f16(prod, 1);
  return h;
}

// Minimum over all 8 lanes, via three rounds of pairwise minima.
template <>
EIGEN_STRONG_INLINE Eigen::half predux_min<Packet8hf>(const Packet8hf& a) {
  const float16x4_t a_lo = vget_low_f16(a);
  const float16x4_t a_hi = vget_high_f16(a);

  float16x4_t min = vpmin_f16(a_lo, a_hi);
  min = vpmin_f16(min, min);
  min = vpmin_f16(min, min);

  Eigen::half h;
  h.x = vget_lane_f16(min, 0);
  return h;
}

// Minimum over all 4 lanes, via two rounds of pairwise minima.
template <>
EIGEN_STRONG_INLINE Eigen::half predux_min<Packet4hf>(const Packet4hf& a) {
  Packet4hf tmp = vpmin_f16(a, a);
  tmp = vpmin_f16(tmp, tmp);

  Eigen::half h;
  h.x = vget_lane_f16(tmp, 0);
  return h;
}

// Maximum over all 8 lanes, via three rounds of pairwise maxima.
template <>
EIGEN_STRONG_INLINE Eigen::half predux_max<Packet8hf>(const Packet8hf& a) {
  const float16x4_t a_lo = vget_low_f16(a);
  const float16x4_t a_hi = vget_high_f16(a);

  float16x4_t max = vpmax_f16(a_lo, a_hi);
  max = vpmax_f16(max, max);
  max = vpmax_f16(max, max);

  Eigen::half h;
  h.x = vget_lane_f16(max, 0);
  return h;
}

// Maximum over all 4 lanes, via two rounds of pairwise maxima.
template <>
EIGEN_STRONG_INLINE Eigen::half predux_max<Packet4hf>(const Packet4hf& a) {
  Packet4hf tmp = vpmax_f16(a, a);
  tmp = vpmax_f16(tmp, tmp);

  Eigen::half h;
  h.x = vget_lane_f16(tmp, 0);
  return h;
}

// ---------------------------------------------------------------------------
// Transposes
// ---------------------------------------------------------------------------

// Rearranges a block of four 8-lane packets through memory.
// Output packet i receives element 2*i of every input packet in its lanes
// 0..3 and element 2*i+1 of every input packet in its lanes 4..7.
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8hf, 4>& kernel) {
  EIGEN_ALIGN16 Eigen::half in[4][8];

  pstore<Eigen::half>(in[0], kernel.packet[0]);
  pstore<Eigen::half>(in[1], kernel.packet[1]);
  pstore<Eigen::half>(in[2], kernel.packet[2]);
  pstore<Eigen::half>(in[3], kernel.packet[3]);

  EIGEN_ALIGN16 Eigen::half out[4][8];

  EIGEN_UNROLL_LOOP
  for (int i = 0; i < 4; ++i) {
    EIGEN_UNROLL_LOOP
    for (int j = 0; j < 4; ++j) {
      out[i][j] = in[j][2*i];
    }
    EIGEN_UNROLL_LOOP
    for (int j = 0; j < 4; ++j) {
      out[i][j+4] = in[j][2*i+1];
    }
  }

  kernel.packet[0] = pload<Packet8hf>(out[0]);
  kernel.packet[1] = pload<Packet8hf>(out[1]);
  kernel.packet[2] = pload<Packet8hf>(out[2]);
  kernel.packet[3] = pload<Packet8hf>(out[3]);
}

// Transposes a 4x4 block of halves in place.
// vld4 reads 16 consecutive halves and de-interleaves them with stride 4, so
// val[k] collects lane k of each of the four packets.
// NOTE(review): this reads `kernel` as 16 contiguous float16_t values, i.e. it
// assumes PacketBlock stores packet[0..3] back to back starting at offset 0;
// confirm against the PacketBlock definition.
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet4hf, 4>& kernel) {
  EIGEN_ALIGN16 float16x4x4_t tmp_x4;
  const float16_t* tmp = reinterpret_cast<const float16_t*>(&kernel);
  tmp_x4 = vld4_f16(tmp);

  kernel.packet[0] = tmp_x4.val[0];
  kernel.packet[1] = tmp_x4.val[1];
  kernel.packet[2] = tmp_x4.val[2];
  kernel.packet[3] = tmp_x4.val[3];
}

// Transposes an 8x8 block of halves in place using three rounds of unzips.
// Each vuzpq splits its two inputs into even-indexed lanes (val[0]) and
// odd-indexed lanes (val[1]); after three rounds each result vector holds one
// column of the original block, and the final assignments put the columns
// back into ascending order.
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8hf, 8>& kernel) {
  float16x8x2_t T_1[4];

  T_1[0] = vuzpq_f16(kernel.packet[0], kernel.packet[1]);
  T_1[1] = vuzpq_f16(kernel.packet[2], kernel.packet[3]);
  T_1[2] = vuzpq_f16(kernel.packet[4], kernel.packet[5]);
  T_1[3] = vuzpq_f16(kernel.packet[6], kernel.packet[7]);

  float16x8x2_t T_2[4];
  T_2[0] = vuzpq_f16(T_1[0].val[0], T_1[1].val[0]);
  T_2[1] = vuzpq_f16(T_1[0].val[1], T_1[1].val[1]);
  T_2[2] = vuzpq_f16(T_1[2].val[0], T_1[3].val[0]);
  T_2[3] = vuzpq_f16(T_1[2].val[1], T_1[3].val[1]);

  float16x8x2_t T_3[4];
  T_3[0] = vuzpq_f16(T_2[0].val[0], T_2[2].val[0]);
  T_3[1] = vuzpq_f16(T_2[0].val[1], T_2[2].val[1]);
  T_3[2] = vuzpq_f16(T_2[1].val[0], T_2[3].val[0]);
  T_3[3] = vuzpq_f16(T_2[1].val[1], T_2[3].val[1]);

  kernel.packet[0] = T_3[0].val[0];
  kernel.packet[1] = T_3[2].val[0];
  kernel.packet[2] = T_3[1].val[0];
  kernel.packet[3] = T_3[3].val[0];
  kernel.packet[4] = T_3[0].val[1];
  kernel.packet[5] = T_3[2].val[1];
  kernel.packet[6] = T_3[1].val[1];
  kernel.packet[7] = T_3[3].val[1];
}
#endif // end EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC