aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/NEON/PacketMath.h
diff options
context:
space:
mode:
authorGravatar David Tellenbach <david.tellenbach@me.com>2020-10-28 20:15:09 +0000
committerGravatar David Tellenbach <david.tellenbach@me.com>2020-10-28 20:15:09 +0000
commite265f7ed8e59c26e15f2c35162c6b8da1c5d594f (patch)
tree09f9696465ca75ecfdaeccda88358f397616042d /Eigen/src/Core/arch/NEON/PacketMath.h
parenta725a3233c98185eb3e5db6186aea3a906b8411f (diff)
Add support for Armv8.2-a __fp16
Armv8.2-a provides a native half-precision floating point type (__fp16, aka float16_t). This patch introduces:
* __fp16 as the underlying type of Eigen::half if this type is available,
* the packet types Packet4hf and Packet8hf, representing float16x4_t and float16x8_t respectively,
* packet-math for the above packets with the corresponding scalar type Eigen::half.
The packet-math functionality has been implemented by Ashutosh Sharma <ashutosh.sharma@amperecomputing.com>. This closes #1940.
Diffstat (limited to 'Eigen/src/Core/arch/NEON/PacketMath.h')
-rw-r--r--Eigen/src/Core/arch/NEON/PacketMath.h644
1 files changed, 644 insertions, 0 deletions
diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h
index 6dbae8cee..dbfb1cdba 100644
--- a/Eigen/src/Core/arch/NEON/PacketMath.h
+++ b/Eigen/src/Core/arch/NEON/PacketMath.h
@@ -3771,6 +3771,650 @@ template<> EIGEN_STRONG_INLINE Packet2d psqrt(const Packet2d& _x){ return vsqrt_
#endif // EIGEN_ARCH_ARM64
+// Do we have fp16 types and the supporting Neon intrinsics?
+#if EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
+// Half-precision packet types: Packet4hf is the NEON float16x4_t vector and
+// Packet8hf is the NEON float16x8_t vector.
+typedef float16x4_t Packet4hf;
+typedef float16x8_t Packet8hf;
+
+// Packet traits for the scalar type Eigen::half.
+//
+// Both the full packet (`type`) and the half packet (`half`) are Packet4hf,
+// with size = 4 and HasHalfPacket = 0; Packet8hf is not advertised here.
+// The Has* flags below declare which packet operations are available for
+// Eigen::half.
+// NOTE(review): HasSetLinear is 0 although plset is specialized for both
+// packet types in this file -- confirm whether this is intentional.
+// TODO(tellenbach): Enable packets of size 8 as soon as the GEBP can handle them
+template <>
+struct packet_traits<Eigen::half> : default_packet_traits {
+ typedef Packet4hf type;
+ typedef Packet4hf half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 4,
+ HasHalfPacket = 0,
+
+ HasCmp = 1,
+ HasCast = 1,
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasAbsDiff = 0,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0,
+ HasInsert = 1,
+ HasReduxp = 1,
+ HasDiv = 1,
+ HasFloor = 1,
+ HasSin = 0,
+ HasCos = 0,
+ HasLog = 0,
+ HasExp = 0,
+ HasSqrt = 1
+ };
+};
+
+// Reverse traits for Packet4hf: scalar type Eigen::half, four lanes, no
+// smaller half packet (`half` is Packet4hf itself), no masked load/store.
+// NOTE(review): Packet4hf is a 64-bit float16x4_t yet alignment is declared
+// as Aligned16 -- confirm this matches the convention used for the other
+// 64-bit NEON packets.
+template <>
+struct unpacket_traits<Packet4hf> {
+ typedef Eigen::half type;
+ typedef Packet4hf half;
+ enum {
+ size = 4,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+
+// Reverse traits for Packet8hf: scalar type Eigen::half, eight lanes,
+// 16-byte alignment, no masked load/store.
+// NOTE(review): `half` is declared as Packet8hf itself rather than
+// Packet4hf -- confirm this is intended while size-8 packets are disabled.
+template <>
+struct unpacket_traits<Packet8hf> {
+ typedef Eigen::half type;
+ typedef Packet8hf half;
+ enum {
+ size = 8,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+
+// Broadcasts the scalar `from` into all eight lanes of a Packet8hf.
+template <>
+EIGEN_STRONG_INLINE Packet8hf pset1<Packet8hf>(const Eigen::half& from) {
+  const Packet8hf broadcast = vdupq_n_f16(from.x);
+  return broadcast;
+}
+
+// Broadcasts the scalar `from` into all four lanes of a Packet4hf.
+template <>
+EIGEN_STRONG_INLINE Packet4hf pset1<Packet4hf>(const Eigen::half& from) {
+  const Packet4hf broadcast = vdup_n_f16(from.x);
+  return broadcast;
+}
+
+// Returns the packet {a, a+1, a+2, ..., a+7}.
+template <>
+EIGEN_STRONG_INLINE Packet8hf plset<Packet8hf>(const Eigen::half& a) {
+  // Per-lane offsets 0..7 that are added to the broadcast base value.
+  const float16_t lane_offsets[] = {0, 1, 2, 3, 4, 5, 6, 7};
+  const Packet8hf base = pset1<Packet8hf>(a);
+  const Packet8hf ramp = vld1q_f16(lane_offsets);
+  return vaddq_f16(base, ramp);
+}
+
+// Returns the packet {a, a+1, a+2, a+3}.
+template <>
+EIGEN_STRONG_INLINE Packet4hf plset<Packet4hf>(const Eigen::half& a) {
+  // Per-lane offsets 0..3 that are added to the broadcast base value.
+  const float16_t lane_offsets[] = {0, 1, 2, 3};
+  const Packet4hf base = pset1<Packet4hf>(a);
+  const Packet4hf ramp = vld1_f16(lane_offsets);
+  return vadd_f16(base, ramp);
+}
+
+// Lane-wise sum a + b of two Packet8hf.
+template <>
+EIGEN_STRONG_INLINE Packet8hf padd<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+  const Packet8hf sum = vaddq_f16(a, b);
+  return sum;
+}
+
+// Lane-wise sum a + b of two Packet4hf.
+template <>
+EIGEN_STRONG_INLINE Packet4hf padd<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+  const Packet4hf sum = vadd_f16(a, b);
+  return sum;
+}
+
+// Lane-wise difference a - b of two Packet8hf.
+template <>
+EIGEN_STRONG_INLINE Packet8hf psub<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+  const Packet8hf difference = vsubq_f16(a, b);
+  return difference;
+}
+
+// Lane-wise difference a - b of two Packet4hf.
+template <>
+EIGEN_STRONG_INLINE Packet4hf psub<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+  const Packet4hf difference = vsub_f16(a, b);
+  return difference;
+}
+
+// Lane-wise negation -a of a Packet8hf.
+template <>
+EIGEN_STRONG_INLINE Packet8hf pnegate(const Packet8hf& a) {
+  const Packet8hf negated = vnegq_f16(a);
+  return negated;
+}
+
+// Lane-wise negation -a of a Packet4hf.
+template <>
+EIGEN_STRONG_INLINE Packet4hf pnegate(const Packet4hf& a) {
+  const Packet4hf negated = vneg_f16(a);
+  return negated;
+}
+
+// Complex conjugate of a Packet8hf; the lanes are real half values, so the
+// input is returned unchanged.
+template <>
+EIGEN_STRONG_INLINE Packet8hf pconj(const Packet8hf& a) {
+ return a;
+}
+
+// Complex conjugate of a Packet4hf; the lanes are real half values, so the
+// input is returned unchanged.
+template <>
+EIGEN_STRONG_INLINE Packet4hf pconj(const Packet4hf& a) {
+ return a;
+}
+
+// Lane-wise product a * b of two Packet8hf.
+template <>
+EIGEN_STRONG_INLINE Packet8hf pmul<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+  const Packet8hf product = vmulq_f16(a, b);
+  return product;
+}
+
+// Lane-wise product a * b of two Packet4hf.
+template <>
+EIGEN_STRONG_INLINE Packet4hf pmul<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+  const Packet4hf product = vmul_f16(a, b);
+  return product;
+}
+
+// Lane-wise quotient a / b of two Packet8hf.
+template <>
+EIGEN_STRONG_INLINE Packet8hf pdiv<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+  const Packet8hf quotient = vdivq_f16(a, b);
+  return quotient;
+}
+
+// Lane-wise quotient a / b of two Packet4hf.
+template <>
+EIGEN_STRONG_INLINE Packet4hf pdiv<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+  const Packet4hf quotient = vdiv_f16(a, b);
+  return quotient;
+}
+
+// Fused multiply-add: returns a * b + c per lane. Note that the intrinsic
+// takes the accumulator `c` as its first argument.
+template <>
+EIGEN_STRONG_INLINE Packet8hf pmadd(const Packet8hf& a, const Packet8hf& b, const Packet8hf& c) {
+  const Packet8hf accumulated = vfmaq_f16(c, a, b);
+  return accumulated;
+}
+
+// Fused multiply-add: returns a * b + c per lane. Note that the intrinsic
+// takes the accumulator `c` as its first argument.
+template <>
+EIGEN_STRONG_INLINE Packet4hf pmadd(const Packet4hf& a, const Packet4hf& b, const Packet4hf& c) {
+  const Packet4hf accumulated = vfma_f16(c, a, b);
+  return accumulated;
+}
+
+// Lane-wise minimum of two Packet8hf.
+template <>
+EIGEN_STRONG_INLINE Packet8hf pmin<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+  const Packet8hf smaller = vminq_f16(a, b);
+  return smaller;
+}
+
+// Lane-wise minimum of two Packet4hf.
+template <>
+EIGEN_STRONG_INLINE Packet4hf pmin<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+  const Packet4hf smaller = vmin_f16(a, b);
+  return smaller;
+}
+
+// Lane-wise maximum of two Packet8hf.
+template <>
+EIGEN_STRONG_INLINE Packet8hf pmax<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+  const Packet8hf larger = vmaxq_f16(a, b);
+  return larger;
+}
+
+// Lane-wise maximum of two Packet4hf.
+template <>
+EIGEN_STRONG_INLINE Packet4hf pmax<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+  const Packet4hf larger = vmax_f16(a, b);
+  return larger;
+}
+
+// Helper macros that stamp out the pcmp_<name> specializations. For a given
+// `name` (eq, lt, le) they call the matching NEON comparison intrinsic
+// vc<name>q_f16 / vc<name>_f16, which yields a uint16 mask vector, and
+// reinterpret that mask as a half packet so it can be returned as a packet.
+#define EIGEN_MAKE_ARM_FP16_CMP_8(name) \
+ template <> \
+ EIGEN_STRONG_INLINE Packet8hf pcmp_##name(const Packet8hf& a, const Packet8hf& b) { \
+ return vreinterpretq_f16_u16(vc##name##q_f16(a, b)); \
+ }
+
+#define EIGEN_MAKE_ARM_FP16_CMP_4(name) \
+ template <> \
+ EIGEN_STRONG_INLINE Packet4hf pcmp_##name(const Packet4hf& a, const Packet4hf& b) { \
+ return vreinterpret_f16_u16(vc##name##_f16(a, b)); \
+ }
+
+// pcmp_eq, pcmp_lt and pcmp_le for Packet8hf.
+EIGEN_MAKE_ARM_FP16_CMP_8(eq)
+EIGEN_MAKE_ARM_FP16_CMP_8(lt)
+EIGEN_MAKE_ARM_FP16_CMP_8(le)
+
+// pcmp_eq, pcmp_lt and pcmp_le for Packet4hf.
+EIGEN_MAKE_ARM_FP16_CMP_4(eq)
+EIGEN_MAKE_ARM_FP16_CMP_4(lt)
+EIGEN_MAKE_ARM_FP16_CMP_4(le)
+
+// The helpers are only needed for the instantiations above.
+#undef EIGEN_MAKE_ARM_FP16_CMP_8
+#undef EIGEN_MAKE_ARM_FP16_CMP_4
+
+// Mask that is set in every lane where a < b or the comparison is unordered.
+// It is computed as the bitwise complement of the (a >= b) mask.
+template <>
+EIGEN_STRONG_INLINE Packet8hf pcmp_lt_or_nan<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+  const uint16x8_t ge_mask = vcgeq_f16(a, b);
+  const uint16x8_t not_ge_mask = vmvnq_u16(ge_mask);
+  return vreinterpretq_f16_u16(not_ge_mask);
+}
+
+// Mask that is set in every lane where a < b or the comparison is unordered.
+// It is computed as the bitwise complement of the (a >= b) mask.
+template <>
+EIGEN_STRONG_INLINE Packet4hf pcmp_lt_or_nan<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+  const uint16x4_t ge_mask = vcge_f16(a, b);
+  const uint16x4_t not_ge_mask = vmvn_u16(ge_mask);
+  return vreinterpret_f16_u16(not_ge_mask);
+}
+
+// Rounds every lane of `a` toward negative infinity.
+//
+// This uses the native round-toward-minus-infinity intrinsic rather than a
+// round trip through int16. The integer conversion saturates at +/-32767
+// although half can represent magnitudes up to 65504, and it converts NaN to
+// 0, so the round-trip approach returned wrong results for large, infinite
+// and NaN inputs and lost the sign of -0.0.
+template <>
+EIGEN_STRONG_INLINE Packet8hf pfloor<Packet8hf>(const Packet8hf& a) {
+  return vrndmq_f16(a);
+}
+
+// Rounds every lane of `a` toward negative infinity.
+//
+// This uses the native round-toward-minus-infinity intrinsic rather than a
+// round trip through int16. The integer conversion saturates at +/-32767
+// although half can represent magnitudes up to 65504, and it converts NaN to
+// 0, so the round-trip approach returned wrong results for large, infinite
+// and NaN inputs and lost the sign of -0.0.
+template <>
+EIGEN_STRONG_INLINE Packet4hf pfloor<Packet4hf>(const Packet4hf& a) {
+  return vrndm_f16(a);
+}
+
+// Lane-wise square root of a Packet8hf.
+template <>
+EIGEN_STRONG_INLINE Packet8hf psqrt<Packet8hf>(const Packet8hf& a) {
+  const Packet8hf root = vsqrtq_f16(a);
+  return root;
+}
+
+// Lane-wise square root of a Packet4hf.
+template <>
+EIGEN_STRONG_INLINE Packet4hf psqrt<Packet4hf>(const Packet4hf& a) {
+  const Packet4hf root = vsqrt_f16(a);
+  return root;
+}
+
+// Bitwise AND of the raw 16-bit lane patterns of a and b.
+template <>
+EIGEN_STRONG_INLINE Packet8hf pand<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+  const uint16x8_t lhs_bits = vreinterpretq_u16_f16(a);
+  const uint16x8_t rhs_bits = vreinterpretq_u16_f16(b);
+  return vreinterpretq_f16_u16(vandq_u16(lhs_bits, rhs_bits));
+}
+
+// Bitwise AND of the raw 16-bit lane patterns of a and b.
+template <>
+EIGEN_STRONG_INLINE Packet4hf pand<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+  const uint16x4_t lhs_bits = vreinterpret_u16_f16(a);
+  const uint16x4_t rhs_bits = vreinterpret_u16_f16(b);
+  return vreinterpret_f16_u16(vand_u16(lhs_bits, rhs_bits));
+}
+
+// Bitwise OR of the raw 16-bit lane patterns of a and b.
+template <>
+EIGEN_STRONG_INLINE Packet8hf por<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+  const uint16x8_t lhs_bits = vreinterpretq_u16_f16(a);
+  const uint16x8_t rhs_bits = vreinterpretq_u16_f16(b);
+  return vreinterpretq_f16_u16(vorrq_u16(lhs_bits, rhs_bits));
+}
+
+// Bitwise OR of the raw 16-bit lane patterns of a and b.
+template <>
+EIGEN_STRONG_INLINE Packet4hf por<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+  const uint16x4_t lhs_bits = vreinterpret_u16_f16(a);
+  const uint16x4_t rhs_bits = vreinterpret_u16_f16(b);
+  return vreinterpret_f16_u16(vorr_u16(lhs_bits, rhs_bits));
+}
+
+// Bitwise XOR of the raw 16-bit lane patterns of a and b.
+template <>
+EIGEN_STRONG_INLINE Packet8hf pxor<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+  const uint16x8_t lhs_bits = vreinterpretq_u16_f16(a);
+  const uint16x8_t rhs_bits = vreinterpretq_u16_f16(b);
+  return vreinterpretq_f16_u16(veorq_u16(lhs_bits, rhs_bits));
+}
+
+// Bitwise XOR of the raw 16-bit lane patterns of a and b.
+template <>
+EIGEN_STRONG_INLINE Packet4hf pxor<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+  const uint16x4_t lhs_bits = vreinterpret_u16_f16(a);
+  const uint16x4_t rhs_bits = vreinterpret_u16_f16(b);
+  return vreinterpret_f16_u16(veor_u16(lhs_bits, rhs_bits));
+}
+
+// Bitwise a AND (NOT b) on the raw 16-bit lane patterns (NEON bit-clear).
+template <>
+EIGEN_STRONG_INLINE Packet8hf pandnot<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+  const uint16x8_t kept_bits = vreinterpretq_u16_f16(a);
+  const uint16x8_t cleared_bits = vreinterpretq_u16_f16(b);
+  return vreinterpretq_f16_u16(vbicq_u16(kept_bits, cleared_bits));
+}
+
+// Bitwise a AND (NOT b) on the raw 16-bit lane patterns (NEON bit-clear).
+template <>
+EIGEN_STRONG_INLINE Packet4hf pandnot<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+  const uint16x4_t kept_bits = vreinterpret_u16_f16(a);
+  const uint16x4_t cleared_bits = vreinterpret_u16_f16(b);
+  return vreinterpret_f16_u16(vbic_u16(kept_bits, cleared_bits));
+}
+
+// Loads eight consecutive halves starting at `from` (aligned-load entry point).
+template <>
+EIGEN_STRONG_INLINE Packet8hf pload<Packet8hf>(const Eigen::half* from) {
+  EIGEN_DEBUG_ALIGNED_LOAD
+  const float16_t* src = reinterpret_cast<const float16_t*>(from);
+  return vld1q_f16(src);
+}
+
+// Loads four consecutive halves starting at `from` (aligned-load entry point).
+template <>
+EIGEN_STRONG_INLINE Packet4hf pload<Packet4hf>(const Eigen::half* from) {
+  EIGEN_DEBUG_ALIGNED_LOAD
+  const float16_t* src = reinterpret_cast<const float16_t*>(from);
+  return vld1_f16(src);
+}
+
+// Loads eight consecutive halves starting at `from` (unaligned-load entry point).
+template <>
+EIGEN_STRONG_INLINE Packet8hf ploadu<Packet8hf>(const Eigen::half* from) {
+  EIGEN_DEBUG_UNALIGNED_LOAD
+  const float16_t* src = reinterpret_cast<const float16_t*>(from);
+  return vld1q_f16(src);
+}
+
+// Loads four consecutive halves starting at `from` (unaligned-load entry point).
+template <>
+EIGEN_STRONG_INLINE Packet4hf ploadu<Packet4hf>(const Eigen::half* from) {
+  EIGEN_DEBUG_UNALIGNED_LOAD
+  const float16_t* src = reinterpret_cast<const float16_t*>(from);
+  return vld1_f16(src);
+}
+
+// Loads four halves from `from` and duplicates each one into two adjacent
+// lanes: {from[0], from[0], from[1], from[1], from[2], from[2], from[3], from[3]}.
+//
+// The packet is built with lane intrinsics (as pgather does) instead of
+// subscripting an uninitialized NEON vector, which relies on a compiler
+// vector extension.
+template <>
+EIGEN_STRONG_INLINE Packet8hf ploaddup<Packet8hf>(const Eigen::half* from) {
+  // The broadcast already places from[0] in lanes 0 and 1.
+  Packet8hf packet = vdupq_n_f16(from[0].x);
+  packet = vsetq_lane_f16(from[1].x, packet, 2);
+  packet = vsetq_lane_f16(from[1].x, packet, 3);
+  packet = vsetq_lane_f16(from[2].x, packet, 4);
+  packet = vsetq_lane_f16(from[2].x, packet, 5);
+  packet = vsetq_lane_f16(from[3].x, packet, 6);
+  packet = vsetq_lane_f16(from[3].x, packet, 7);
+  return packet;
+}
+
+// Loads two halves from `from` and duplicates each one into two adjacent
+// lanes: {from[0], from[0], from[1], from[1]}.
+//
+// The packet is built with lane intrinsics (as pgather does) instead of
+// writing through a C-style cast pointer into an uninitialized vector.
+template <>
+EIGEN_STRONG_INLINE Packet4hf ploaddup<Packet4hf>(const Eigen::half* from) {
+  // The broadcast already places from[0] in lanes 0 and 1.
+  Packet4hf packet = vdup_n_f16(from[0].x);
+  packet = vset_lane_f16(from[1].x, packet, 2);
+  packet = vset_lane_f16(from[1].x, packet, 3);
+  return packet;
+}
+
+// Loads two halves and replicates each four times:
+// {from[0] x4, from[1] x4}.
+template <>
+EIGEN_STRONG_INLINE Packet8hf ploadquad<Packet8hf>(const Eigen::half* from) {
+  const float16_t* first = reinterpret_cast<const float16_t*>(from);
+  const float16_t* second = reinterpret_cast<const float16_t*>(from+1);
+  const Packet4hf low_half = vld1_dup_f16(first);
+  const Packet4hf high_half = vld1_dup_f16(second);
+  return vcombine_f16(low_half, high_half);
+}
+
+// Returns a copy of `a` whose first lane (index 0) is replaced by `b`.
+EIGEN_DEVICE_FUNC inline Packet8hf pinsertfirst(const Packet8hf& a, Eigen::half b) {
+  const Packet8hf updated = vsetq_lane_f16(b.x, a, 0);
+  return updated;
+}
+
+// Returns a copy of `a` whose first lane (index 0) is replaced by `b`.
+EIGEN_DEVICE_FUNC inline Packet4hf pinsertfirst(const Packet4hf& a, Eigen::half b) {
+  const Packet4hf updated = vset_lane_f16(b.x, a, 0);
+  return updated;
+}
+
+// Bitwise select: takes bits from `a` where the corresponding bit of `mask`
+// is set and from `b` where it is clear.
+template <>
+EIGEN_DEVICE_FUNC inline Packet8hf pselect(const Packet8hf& mask, const Packet8hf& a, const Packet8hf& b) {
+  const uint16x8_t mask_bits = vreinterpretq_u16_f16(mask);
+  return vbslq_f16(mask_bits, a, b);
+}
+
+// Bitwise select: takes bits from `a` where the corresponding bit of `mask`
+// is set and from `b` where it is clear.
+template <>
+EIGEN_DEVICE_FUNC inline Packet4hf pselect(const Packet4hf& mask, const Packet4hf& a, const Packet4hf& b) {
+  const uint16x4_t mask_bits = vreinterpret_u16_f16(mask);
+  return vbsl_f16(mask_bits, a, b);
+}
+
+// Returns a copy of `a` whose last lane (index 7) is replaced by `b`.
+EIGEN_DEVICE_FUNC inline Packet8hf pinsertlast(const Packet8hf& a, Eigen::half b) {
+  const Packet8hf updated = vsetq_lane_f16(b.x, a, 7);
+  return updated;
+}
+
+// Returns a copy of `a` whose last lane (index 3) is replaced by `b`.
+EIGEN_DEVICE_FUNC inline Packet4hf pinsertlast(const Packet4hf& a, Eigen::half b) {
+  const Packet4hf updated = vset_lane_f16(b.x, a, 3);
+  return updated;
+}
+
+// Stores the eight lanes of `from` to `to` (aligned-store entry point).
+template <>
+EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet8hf& from) {
+  EIGEN_DEBUG_ALIGNED_STORE
+  float16_t* dst = reinterpret_cast<float16_t*>(to);
+  vst1q_f16(dst, from);
+}
+
+// Stores the four lanes of `from` to `to` (aligned-store entry point).
+template <>
+EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet4hf& from) {
+  EIGEN_DEBUG_ALIGNED_STORE
+  float16_t* dst = reinterpret_cast<float16_t*>(to);
+  vst1_f16(dst, from);
+}
+
+// Stores the eight lanes of `from` to `to` (unaligned-store entry point).
+template <>
+EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const Packet8hf& from) {
+  EIGEN_DEBUG_UNALIGNED_STORE
+  float16_t* dst = reinterpret_cast<float16_t*>(to);
+  vst1q_f16(dst, from);
+}
+
+// Stores the four lanes of `from` to `to` (unaligned-store entry point).
+template <>
+EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const Packet4hf& from) {
+  EIGEN_DEBUG_UNALIGNED_STORE
+  float16_t* dst = reinterpret_cast<float16_t*>(to);
+  vst1_f16(dst, from);
+}
+
+// Strided gather: lane i of the result is from[i * stride], for i = 0..7.
+template <>
+EIGEN_DEVICE_FUNC inline Packet8hf pgather<Eigen::half, Packet8hf>(const Eigen::half* from, Index stride) {
+  // Start from an all-zero packet, then overwrite every lane in turn.
+  Packet8hf gathered = pset1<Packet8hf>(Eigen::half(0.f));
+  gathered = vsetq_lane_f16(from[stride * 0].x, gathered, 0);
+  gathered = vsetq_lane_f16(from[stride * 1].x, gathered, 1);
+  gathered = vsetq_lane_f16(from[stride * 2].x, gathered, 2);
+  gathered = vsetq_lane_f16(from[stride * 3].x, gathered, 3);
+  gathered = vsetq_lane_f16(from[stride * 4].x, gathered, 4);
+  gathered = vsetq_lane_f16(from[stride * 5].x, gathered, 5);
+  gathered = vsetq_lane_f16(from[stride * 6].x, gathered, 6);
+  gathered = vsetq_lane_f16(from[stride * 7].x, gathered, 7);
+  return gathered;
+}
+
+// Strided gather: lane i of the result is from[i * stride], for i = 0..3.
+template <>
+EIGEN_DEVICE_FUNC inline Packet4hf pgather<Eigen::half, Packet4hf>(const Eigen::half* from, Index stride) {
+  // Start from an all-zero packet, then overwrite every lane in turn.
+  Packet4hf gathered = pset1<Packet4hf>(Eigen::half(0.f));
+  gathered = vset_lane_f16(from[stride * 0].x, gathered, 0);
+  gathered = vset_lane_f16(from[stride * 1].x, gathered, 1);
+  gathered = vset_lane_f16(from[stride * 2].x, gathered, 2);
+  gathered = vset_lane_f16(from[stride * 3].x, gathered, 3);
+  return gathered;
+}
+
+// Strided scatter: writes lane i of `from` to to[i * stride], for i = 0..7.
+template <>
+EIGEN_DEVICE_FUNC inline void pscatter<Eigen::half, Packet8hf>(Eigen::half* to, const Packet8hf& from, Index stride) {
+  to[0 * stride].x = vgetq_lane_f16(from, 0);
+  to[1 * stride].x = vgetq_lane_f16(from, 1);
+  to[2 * stride].x = vgetq_lane_f16(from, 2);
+  to[3 * stride].x = vgetq_lane_f16(from, 3);
+  to[4 * stride].x = vgetq_lane_f16(from, 4);
+  to[5 * stride].x = vgetq_lane_f16(from, 5);
+  to[6 * stride].x = vgetq_lane_f16(from, 6);
+  to[7 * stride].x = vgetq_lane_f16(from, 7);
+}
+
+// Strided scatter: writes lane i of `from` to to[i * stride], for i = 0..3.
+template <>
+EIGEN_DEVICE_FUNC inline void pscatter<Eigen::half, Packet4hf>(Eigen::half* to, const Packet4hf& from, Index stride) {
+  to[0 * stride].x = vget_lane_f16(from, 0);
+  to[1 * stride].x = vget_lane_f16(from, 1);
+  to[2 * stride].x = vget_lane_f16(from, 2);
+  to[3 * stride].x = vget_lane_f16(from, 3);
+}
+
+// Issues a prefetch hint for the memory at `addr` through the
+// EIGEN_ARM_PREFETCH macro.
+template <>
+EIGEN_STRONG_INLINE void prefetch<Eigen::half>(const Eigen::half* addr) {
+ EIGEN_ARM_PREFETCH(addr);
+}
+
+// Returns lane 0 of `a` as an Eigen::half.
+//
+// The lane is extracted directly with vgetq_lane_f16 (as pscatter does)
+// rather than storing the whole vector to a stack array to read one element.
+template <>
+EIGEN_STRONG_INLINE Eigen::half pfirst<Packet8hf>(const Packet8hf& a) {
+  Eigen::half h;
+  h.x = vgetq_lane_f16(a, 0);
+  return h;
+}
+
+// Returns lane 0 of `a` as an Eigen::half.
+//
+// The lane is extracted directly with vget_lane_f16 (as pscatter does)
+// rather than storing the whole vector to a stack array to read one element.
+template <>
+EIGEN_STRONG_INLINE Eigen::half pfirst<Packet4hf>(const Packet4hf& a) {
+  Eigen::half h;
+  h.x = vget_lane_f16(a, 0);
+  return h;
+}
+
+// Reverses the order of the eight lanes of `a`.
+template<> EIGEN_STRONG_INLINE Packet8hf preverse(const Packet8hf& a) {
+  // vrev64q reverses the lanes inside each 64-bit half; swapping the two
+  // halves afterwards completes the full reversal.
+  const Packet8hf reversed_halves = vrev64q_f16(a);
+  const float16x4_t new_low = vget_high_f16(reversed_halves);
+  const float16x4_t new_high = vget_low_f16(reversed_halves);
+  return vcombine_f16(new_low, new_high);
+}
+
+// Reverses the order of the four lanes of `a`.
+template <>
+EIGEN_STRONG_INLINE Packet4hf preverse<Packet4hf>(const Packet4hf& a) {
+  const Packet4hf reversed = vrev64_f16(a);
+  return reversed;
+}
+
+// Lane-wise absolute value of a Packet8hf.
+template <>
+EIGEN_STRONG_INLINE Packet8hf pabs<Packet8hf>(const Packet8hf& a) {
+  const Packet8hf magnitude = vabsq_f16(a);
+  return magnitude;
+}
+
+// Lane-wise absolute value of a Packet4hf.
+template <>
+EIGEN_STRONG_INLINE Packet4hf pabs<Packet4hf>(const Packet4hf& a) {
+  const Packet4hf magnitude = vabs_f16(a);
+  return magnitude;
+}
+
+// Horizontal sum of the eight lanes of `a`.
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux<Packet8hf>(const Packet8hf& a) {
+  const float16x4_t low = vget_low_f16(a);
+  const float16x4_t high = vget_high_f16(a);
+  // Three rounds of pairwise addition fold 8 -> 4 -> 2 -> 1 partial sums;
+  // after the last round lane 0 holds the total.
+  float16x4_t partial = vpadd_f16(low, high);
+  partial = vpadd_f16(partial, partial);
+  partial = vpadd_f16(partial, partial);
+  Eigen::half total;
+  total.x = vget_lane_f16(partial, 0);
+  return total;
+}
+
+// Horizontal sum of the four lanes of `a`.
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux<Packet4hf>(const Packet4hf& a) {
+  // Two rounds of pairwise addition fold 4 -> 2 -> 1 partial sums; after
+  // the last round lane 0 holds the total.
+  float16x4_t partial = vpadd_f16(a, a);
+  partial = vpadd_f16(partial, partial);
+  Eigen::half total;
+  total.x = vget_lane_f16(partial, 0);
+  return total;
+}
+
+// Horizontal product of the eight lanes of `a`.
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux_mul<Packet8hf>(const Packet8hf& a) {
+  const float16x4_t low = vget_low_f16(a);
+  const float16x4_t high = vget_high_f16(a);
+  // Multiply the two halves lane-wise, then multiply the result with its
+  // own reversal; lanes 0 and 1 then jointly cover all eight inputs.
+  float16x4_t partial = vmul_f16(low, high);
+  partial = vmul_f16(partial, vrev64_f16(partial));
+  Eigen::half result;
+  result.x = vget_lane_f16(partial, 0) * vget_lane_f16(partial, 1);
+  return result;
+}
+
+// Horizontal product of the four lanes of `a`.
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux_mul<Packet4hf>(const Packet4hf& a) {
+  // Multiplying `a` with its reversal leaves a0*a3 in lane 0 and a1*a2 in
+  // lane 1; their product is the full reduction.
+  const float16x4_t partial = vmul_f16(a, vrev64_f16(a));
+  Eigen::half result;
+  result.x = vget_lane_f16(partial, 0) * vget_lane_f16(partial, 1);
+  return result;
+}
+
+// Horizontal minimum of the eight lanes of `a`.
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux_min<Packet8hf>(const Packet8hf& a) {
+  const float16x4_t low = vget_low_f16(a);
+  const float16x4_t high = vget_high_f16(a);
+  // Three rounds of pairwise minimum fold 8 -> 4 -> 2 -> 1 candidates;
+  // after the last round lane 0 holds the overall minimum.
+  float16x4_t candidates = vpmin_f16(low, high);
+  candidates = vpmin_f16(candidates, candidates);
+  candidates = vpmin_f16(candidates, candidates);
+  Eigen::half smallest;
+  smallest.x = vget_lane_f16(candidates, 0);
+  return smallest;
+}
+
+// Horizontal minimum of the four lanes of `a`.
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux_min<Packet4hf>(const Packet4hf& a) {
+  // Two rounds of pairwise minimum fold 4 -> 2 -> 1 candidates; after the
+  // last round lane 0 holds the overall minimum.
+  Packet4hf candidates = vpmin_f16(a, a);
+  candidates = vpmin_f16(candidates, candidates);
+  Eigen::half smallest;
+  smallest.x = vget_lane_f16(candidates, 0);
+  return smallest;
+}
+
+// Horizontal maximum of the eight lanes of `a`.
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux_max<Packet8hf>(const Packet8hf& a) {
+  const float16x4_t low = vget_low_f16(a);
+  const float16x4_t high = vget_high_f16(a);
+  // Three rounds of pairwise maximum fold 8 -> 4 -> 2 -> 1 candidates;
+  // after the last round lane 0 holds the overall maximum.
+  float16x4_t candidates = vpmax_f16(low, high);
+  candidates = vpmax_f16(candidates, candidates);
+  candidates = vpmax_f16(candidates, candidates);
+  Eigen::half largest;
+  largest.x = vget_lane_f16(candidates, 0);
+  return largest;
+}
+
+// Horizontal maximum of the four lanes of `a`.
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux_max<Packet4hf>(const Packet4hf& a) {
+  // Two rounds of pairwise maximum fold 4 -> 2 -> 1 candidates; after the
+  // last round lane 0 holds the overall maximum.
+  Packet4hf candidates = vpmax_f16(a, a);
+  candidates = vpmax_f16(candidates, candidates);
+  Eigen::half largest;
+  largest.x = vget_lane_f16(candidates, 0);
+  return largest;
+}
+
+// Transposes a block of four Packet8hf in place, going through memory.
+// On return, packet i holds element 2*i of each of the four input packets in
+// lanes 0..3 and element 2*i+1 of each input packet in lanes 4..7.
+EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8hf, 4>& kernel) {
+  EIGEN_ALIGN16 Eigen::half src[4][8];
+  EIGEN_ALIGN16 Eigen::half dst[4][8];
+
+  // Spill the four input packets to the scratch array.
+  EIGEN_UNROLL_LOOP
+  for (int row = 0; row < 4; ++row) {
+    pstore<Eigen::half>(src[row], kernel.packet[row]);
+  }
+
+  EIGEN_UNROLL_LOOP
+  for (int i = 0; i < 4; ++i) {
+    EIGEN_UNROLL_LOOP
+    for (int j = 0; j < 4; ++j) {
+      dst[i][j] = src[j][2 * i];
+      dst[i][j + 4] = src[j][2 * i + 1];
+    }
+  }
+
+  // Reload the rearranged rows into the block.
+  EIGEN_UNROLL_LOOP
+  for (int row = 0; row < 4; ++row) {
+    kernel.packet[row] = pload<Packet8hf>(dst[row]);
+  }
+}
+
+// Transposes a 4x4 block of halves in place.
+// The block is re-read as 16 consecutive halves with a de-interleaving vld4
+// load, so that result vector k collects every fourth element starting at k.
+// NOTE(review): this assumes PacketBlock stores its four packets contiguously
+// starting at the object's address -- confirm against the PacketBlock
+// definition.
+EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet4hf, 4>& kernel) {
+  float16_t* raw = (float16_t*)&kernel;
+  EIGEN_ALIGN16 float16x4x4_t columns;
+  columns = vld4_f16(raw);
+
+  EIGEN_UNROLL_LOOP
+  for (int k = 0; k < 4; ++k) {
+    kernel.packet[k] = columns.val[k];
+  }
+}
+
+// Transposes an 8x8 block of halves in place using three rounds of NEON
+// unzip (vuzpq) operations. Each vuzpq splits its two inputs into the
+// even-indexed lanes (val[0]) and the odd-indexed lanes (val[1]).
+EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8hf, 8>& kernel) {
+ float16x8x2_t T_1[4];
+
+ // Round 1: unzip adjacent row pairs (0,1), (2,3), (4,5), (6,7).
+ T_1[0] = vuzpq_f16(kernel.packet[0], kernel.packet[1]);
+ T_1[1] = vuzpq_f16(kernel.packet[2], kernel.packet[3]);
+ T_1[2] = vuzpq_f16(kernel.packet[4], kernel.packet[5]);
+ T_1[3] = vuzpq_f16(kernel.packet[6], kernel.packet[7]);
+
+ // Round 2: unzip the even parts and the odd parts of round 1 separately,
+ // within rows 0..3 (T_2[0], T_2[1]) and rows 4..7 (T_2[2], T_2[3]).
+ float16x8x2_t T_2[4];
+ T_2[0] = vuzpq_f16(T_1[0].val[0], T_1[1].val[0]);
+ T_2[1] = vuzpq_f16(T_1[0].val[1], T_1[1].val[1]);
+ T_2[2] = vuzpq_f16(T_1[2].val[0], T_1[3].val[0]);
+ T_2[3] = vuzpq_f16(T_1[2].val[1], T_1[3].val[1]);
+
+ // Round 3: combine the rows-0..3 results with the rows-4..7 results so that
+ // every output vector holds one element from each of the eight input rows.
+ float16x8x2_t T_3[4];
+ T_3[0] = vuzpq_f16(T_2[0].val[0], T_2[2].val[0]);
+ T_3[1] = vuzpq_f16(T_2[0].val[1], T_2[2].val[1]);
+ T_3[2] = vuzpq_f16(T_2[1].val[0], T_2[3].val[0]);
+ T_3[3] = vuzpq_f16(T_2[1].val[1], T_2[3].val[1]);
+
+ // The unzip network yields the columns in the order 0,4 / 2,6 / 1,5 / 3,7
+ // across T_3[0..3]; write them back in natural column order.
+ kernel.packet[0] = T_3[0].val[0];
+ kernel.packet[1] = T_3[2].val[0];
+ kernel.packet[2] = T_3[1].val[0];
+ kernel.packet[3] = T_3[3].val[0];
+ kernel.packet[4] = T_3[0].val[1];
+ kernel.packet[5] = T_3[2].val[1];
+ kernel.packet[6] = T_3[1].val[1];
+ kernel.packet[7] = T_3[3].val[1];
+}
+#endif // end EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
+
} // end namespace internal
} // end namespace Eigen