aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author Antonio Sanchez <cantonios@google.com> 2021-05-24 21:34:35 -0700
committer Rasmus Munk Larsen <rmlarsen@google.com> 2021-05-25 18:25:35 +0000
commit dba753a986b527a17c8cc62474d0487aec7c2b36 (patch)
tree 70760b526cc358e8c3dc107072d36277ad8228dc
parent ebb300d0b4340104dcade3afa656a57da2b7660c (diff)
Add missing NEON ptranspose implementations.
Unified implementation using only `vzip`.
-rw-r--r-- Eigen/src/Core/arch/NEON/PacketMath.h 507
-rw-r--r-- test/packetmath.cpp 32
2 files changed, 226 insertions, 313 deletions
diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h
index 2b48570d1..73a35c570 100644
--- a/Eigen/src/Core/arch/NEON/PacketMath.h
+++ b/Eigen/src/Core/arch/NEON/PacketMath.h
@@ -2774,352 +2774,265 @@ template<> EIGEN_STRONG_INLINE bool predux_any(const Packet4f& x)
return vget_lane_u32(vpmax_u32(tmp, tmp), 0);
}
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet2f, 2>& kernel)
-{
- const float32x2x2_t z = vzip_f32(kernel.packet[0], kernel.packet[1]);
- kernel.packet[0] = z.val[0];
- kernel.packet[1] = z.val[1];
+// Helpers for ptranspose.
+namespace detail {
+
+template<typename Packet>
+void zip_in_place(Packet& p1, Packet& p2);  // declaration only; the definitions are the per-packet-type specializations
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet2f>(Packet2f& p1, Packet2f& p2) {  // in-place zip of two Packet2f values
+ const float32x2x2_t tmp = vzip_f32(p1, p2);  // val[0]: interleaved low lanes, val[1]: interleaved high lanes
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
}
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4f, 4>& kernel)
-{
- const float32x4x2_t tmp1 = vzipq_f32(kernel.packet[0], kernel.packet[1]);
- const float32x4x2_t tmp2 = vzipq_f32(kernel.packet[2], kernel.packet[3]);
- kernel.packet[0] = vcombine_f32(vget_low_f32(tmp1.val[0]), vget_low_f32(tmp2.val[0]));
- kernel.packet[1] = vcombine_f32(vget_high_f32(tmp1.val[0]), vget_high_f32(tmp2.val[0]));
- kernel.packet[2] = vcombine_f32(vget_low_f32(tmp1.val[1]), vget_low_f32(tmp2.val[1]));
- kernel.packet[3] = vcombine_f32(vget_high_f32(tmp1.val[1]), vget_high_f32(tmp2.val[1]));
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet4f>(Packet4f& p1, Packet4f& p2) {  // in-place zip of two Packet4f values
+ const float32x4x2_t tmp = vzipq_f32(p1, p2);  // val[0]: interleaved low lanes, val[1]: interleaved high lanes
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
}
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4c, 4>& kernel)
-{
- const int8x8_t a = vreinterpret_s8_s32(vset_lane_s32(kernel.packet[2], vdup_n_s32(kernel.packet[0]), 1));
- const int8x8_t b = vreinterpret_s8_s32(vset_lane_s32(kernel.packet[3], vdup_n_s32(kernel.packet[1]), 1));
- const int8x8x2_t zip8 = vzip_s8(a,b);
- const int16x4x2_t zip16 = vzip_s16(vreinterpret_s16_s8(zip8.val[0]), vreinterpret_s16_s8(zip8.val[1]));
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet8c>(Packet8c& p1, Packet8c& p2) {  // in-place zip of two Packet8c values
+ const int8x8x2_t tmp = vzip_s8(p1, p2);  // val[0]: interleaved low lanes, val[1]: interleaved high lanes
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
- kernel.packet[0] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[0]), 0);
- kernel.packet[1] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[0]), 1);
- kernel.packet[2] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[1]), 0);
- kernel.packet[3] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[1]), 1);
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet16c>(Packet16c& p1, Packet16c& p2) {  // in-place zip of two Packet16c values
+ const int8x16x2_t tmp = vzipq_s8(p1, p2);  // val[0]: interleaved low lanes, val[1]: interleaved high lanes
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
}
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8c, 8>& kernel)
-{
- int8x8x2_t zip8[4];
- uint16x4x2_t zip16[4];
- EIGEN_UNROLL_LOOP
- for (int i = 0; i != 4; i++)
- zip8[i] = vzip_s8(kernel.packet[i*2], kernel.packet[i*2+1]);
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet8uc>(Packet8uc& p1, Packet8uc& p2) {  // in-place zip of two Packet8uc values
+ const uint8x8x2_t tmp = vzip_u8(p1, p2);  // val[0]: interleaved low lanes, val[1]: interleaved high lanes
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
- EIGEN_UNROLL_LOOP
- for (int i = 0; i != 2; i++)
- {
- EIGEN_UNROLL_LOOP
- for (int j = 0; j != 2; j++)
- zip16[i*2+j] = vzip_u16(vreinterpret_u16_s8(zip8[i*2].val[j]), vreinterpret_u16_s8(zip8[i*2+1].val[j]));
- }
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet16uc>(Packet16uc& p1, Packet16uc& p2) {  // in-place zip of two Packet16uc values
+ const uint8x16x2_t tmp = vzipq_u8(p1, p2);  // val[0]: interleaved low lanes, val[1]: interleaved high lanes
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
- EIGEN_UNROLL_LOOP
- for (int i = 0; i != 2; i++)
- {
- EIGEN_UNROLL_LOOP
- for (int j = 0; j != 2; j++)
- {
- const uint32x2x2_t z = vzip_u32(vreinterpret_u32_u16(zip16[i].val[j]), vreinterpret_u32_u16(zip16[i+2].val[j]));
- EIGEN_UNROLL_LOOP
- for (int k = 0; k != 2; k++)
- kernel.packet[i*4+j*2+k] = vreinterpret_s8_u32(z.val[k]);
- }
- }
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet2i>(Packet2i& p1, Packet2i& p2) {  // in-place zip of two Packet2i values
+ const int32x2x2_t tmp = vzip_s32(p1, p2);  // val[0]: interleaved low lanes, val[1]: interleaved high lanes
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
}
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16c, 16>& kernel)
-{
- int8x16x2_t zip8[8];
- uint16x8x2_t zip16[8];
- uint32x4x2_t zip32[8];
- EIGEN_UNROLL_LOOP
- for (int i = 0; i != 8; i++)
- zip8[i] = vzipq_s8(kernel.packet[i*2], kernel.packet[i*2+1]);
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet4i>(Packet4i& p1, Packet4i& p2) {  // in-place zip of two Packet4i values
+ const int32x4x2_t tmp = vzipq_s32(p1, p2);  // val[0]: interleaved low lanes, val[1]: interleaved high lanes
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
- EIGEN_UNROLL_LOOP
- for (int i = 0; i != 4; i++)
- {
- EIGEN_UNROLL_LOOP
- for (int j = 0; j != 2; j++)
- {
- zip16[i*2+j] = vzipq_u16(vreinterpretq_u16_s8(zip8[i*2].val[j]),
- vreinterpretq_u16_s8(zip8[i*2+1].val[j]));
- }
- }
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet2ui>(Packet2ui& p1, Packet2ui& p2) {  // in-place zip of two Packet2ui values
+ const uint32x2x2_t tmp = vzip_u32(p1, p2);  // val[0]: interleaved low lanes, val[1]: interleaved high lanes
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
- EIGEN_UNROLL_LOOP
- for (int i = 0; i != 2; i++)
- {
- EIGEN_UNROLL_LOOP
- for (int j = 0; j != 2; j++)
- {
- EIGEN_UNROLL_LOOP
- for (int k = 0; k != 2; k++)
- zip32[i*4+j*2+k] = vzipq_u32(vreinterpretq_u32_u16(zip16[i*4+j].val[k]),
- vreinterpretq_u32_u16(zip16[i*4+j+2].val[k]));
- }
- }
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet4ui>(Packet4ui& p1, Packet4ui& p2) {  // in-place zip of two Packet4ui values
+ const uint32x4x2_t tmp = vzipq_u32(p1, p2);  // val[0]: interleaved low lanes, val[1]: interleaved high lanes
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
- EIGEN_UNROLL_LOOP
- for (int i = 0; i != 4; i++)
- {
- EIGEN_UNROLL_LOOP
- for (int j = 0; j != 2; j++)
- {
- kernel.packet[i*4+j*2] = vreinterpretq_s8_u32(vcombine_u32(vget_low_u32(zip32[i].val[j]),
- vget_low_u32(zip32[i+4].val[j])));
- kernel.packet[i*4+j*2+1] = vreinterpretq_s8_u32(vcombine_u32(vget_high_u32(zip32[i].val[j]),
- vget_high_u32(zip32[i+4].val[j])));
- }
- }
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet4s>(Packet4s& p1, Packet4s& p2) {  // in-place zip of two Packet4s values
+ const int16x4x2_t tmp = vzip_s16(p1, p2);  // val[0]: interleaved low lanes, val[1]: interleaved high lanes
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
}
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4uc, 4>& kernel)
-{
- const uint8x8_t a = vreinterpret_u8_u32(vset_lane_u32(kernel.packet[2], vdup_n_u32(kernel.packet[0]), 1));
- const uint8x8_t b = vreinterpret_u8_u32(vset_lane_u32(kernel.packet[3], vdup_n_u32(kernel.packet[1]), 1));
- const uint8x8x2_t zip8 = vzip_u8(a,b);
- const uint16x4x2_t zip16 = vzip_u16(vreinterpret_u16_u8(zip8.val[0]), vreinterpret_u16_u8(zip8.val[1]));
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet8s>(Packet8s& p1, Packet8s& p2) {  // in-place zip of two Packet8s values
+ const int16x8x2_t tmp = vzipq_s16(p1, p2);  // val[0]: interleaved low lanes, val[1]: interleaved high lanes
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
- kernel.packet[0] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[0]), 0);
- kernel.packet[1] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[0]), 1);
- kernel.packet[2] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[1]), 0);
- kernel.packet[3] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[1]), 1);
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet4us>(Packet4us& p1, Packet4us& p2) {  // in-place zip of two Packet4us values
+ const uint16x4x2_t tmp = vzip_u16(p1, p2);  // val[0]: interleaved low lanes, val[1]: interleaved high lanes
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
}
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8uc, 8>& kernel)
-{
- uint8x8x2_t zip8[4];
- uint16x4x2_t zip16[4];
- EIGEN_UNROLL_LOOP
- for (int i = 0; i != 4; i++)
- zip8[i] = vzip_u8(kernel.packet[i*2], kernel.packet[i*2+1]);
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet8us>(Packet8us& p1, Packet8us& p2) {  // in-place zip of two Packet8us values
+ const uint16x8x2_t tmp = vzipq_u16(p1, p2);  // val[0]: interleaved low lanes, val[1]: interleaved high lanes
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
- EIGEN_UNROLL_LOOP
- for (int i = 0; i != 2; i++)
- {
- EIGEN_UNROLL_LOOP
- for (int j = 0; j != 2; j++)
- zip16[i*2+j] = vzip_u16(vreinterpret_u16_u8(zip8[i*2].val[j]), vreinterpret_u16_u8(zip8[i*2+1].val[j]));
- }
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void ptranspose_impl(PacketBlock<Packet, 2>& kernel) {  // 2-packet block: a single zip of packet[0] with packet[1]
+ zip_in_place(kernel.packet[0], kernel.packet[1]);
+}
- EIGEN_UNROLL_LOOP
- for (int i = 0; i != 2; i++)
- {
- EIGEN_UNROLL_LOOP
- for (int j = 0; j != 2; j++)
- {
- const uint32x2x2_t z = vzip_u32(vreinterpret_u32_u16(zip16[i].val[j]), vreinterpret_u32_u16(zip16[i+2].val[j]));
- EIGEN_UNROLL_LOOP
- for (int k = 0; k != 2; k++)
- kernel.packet[i*4+j*2+k] = vreinterpret_u8_u32(z.val[k]);
- }
- }
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void ptranspose_impl(PacketBlock<Packet, 4>& kernel) {  // 4-packet block: two rounds of in-place zips
+ zip_in_place(kernel.packet[0], kernel.packet[2]);  // round 1: zip packets that are 2 apart
+ zip_in_place(kernel.packet[1], kernel.packet[3]);
+ zip_in_place(kernel.packet[0], kernel.packet[1]);  // round 2: zip adjacent packets
+ zip_in_place(kernel.packet[2], kernel.packet[3]);
}
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16uc, 16>& kernel)
-{
- uint8x16x2_t zip8[8];
- uint16x8x2_t zip16[8];
- uint32x4x2_t zip32[8];
- EIGEN_UNROLL_LOOP
- for (int i = 0; i != 8; i++)
- zip8[i] = vzipq_u8(kernel.packet[i*2], kernel.packet[i*2+1]);
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void ptranspose_impl(PacketBlock<Packet, 8>& kernel) {  // 8-packet block: three rounds of in-place zips
+ zip_in_place(kernel.packet[0], kernel.packet[4]);  // round 1: zip packets that are 4 apart
+ zip_in_place(kernel.packet[1], kernel.packet[5]);
+ zip_in_place(kernel.packet[2], kernel.packet[6]);
+ zip_in_place(kernel.packet[3], kernel.packet[7]);
- EIGEN_UNROLL_LOOP
- for (int i = 0; i != 4; i++)
- {
- EIGEN_UNROLL_LOOP
- for (int j = 0; j != 2; j++)
- zip16[i*2+j] = vzipq_u16(vreinterpretq_u16_u8(zip8[i*2].val[j]),
- vreinterpretq_u16_u8(zip8[i*2+1].val[j]));
- }
+ zip_in_place(kernel.packet[0], kernel.packet[2]);  // round 2: zip packets that are 2 apart, within each half {0..3}, {4..7}
+ zip_in_place(kernel.packet[1], kernel.packet[3]);
+ zip_in_place(kernel.packet[4], kernel.packet[6]);
+ zip_in_place(kernel.packet[5], kernel.packet[7]);
+
+ zip_in_place(kernel.packet[0], kernel.packet[1]);  // round 3: zip adjacent packets
+ zip_in_place(kernel.packet[2], kernel.packet[3]);
+ zip_in_place(kernel.packet[4], kernel.packet[5]);
+ zip_in_place(kernel.packet[6], kernel.packet[7]);
+}
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void ptranspose_impl(PacketBlock<Packet, 16>& kernel) {  // 16-packet block: four rounds of in-place zips, written as loops
 EIGEN_UNROLL_LOOP
- for (int i = 0; i != 2; i++)
- {
+ for (int i=0; i<4; ++i) {  // i: round index, 0..3
+ const int m = (1 << i);  // m: number of packet groups in round i (1, 2, 4, 8)
 EIGEN_UNROLL_LOOP
- for (int j = 0; j != 2; j++)
- {
+ for (int j=0; j<m; ++j) {  // j: group index; each group covers 2*n consecutive packets
+ const int n = (1 << (3-i));  // n: distance between zipped packets in round i (8, 4, 2, 1)
 EIGEN_UNROLL_LOOP
- for (int k = 0; k != 2; k++)
- zip32[i*4+j*2+k] = vzipq_u32(vreinterpretq_u32_u16(zip16[i*4+j].val[k]),
- vreinterpretq_u32_u16(zip16[i*4+j+2].val[k]));
- }
- }
-
- EIGEN_UNROLL_LOOP
- for (int i = 0; i != 4; i++)
- {
- EIGEN_UNROLL_LOOP
- for (int j = 0; j != 2; j++)
- {
- kernel.packet[i*4+j*2] = vreinterpretq_u8_u32(vcombine_u32(vget_low_u32(zip32[i].val[j]),
- vget_low_u32(zip32[i+4].val[j])));
- kernel.packet[i*4+j*2+1] = vreinterpretq_u8_u32(vcombine_u32(vget_high_u32(zip32[i].val[j]),
- vget_high_u32(zip32[i+4].val[j])));
+ for (int k=0; k<n; ++k) {  // k: offset within the first half of group j
+ const int idx = 2*j*n+k;  // group j starts at packet 2*j*n
+ zip_in_place(kernel.packet[idx], kernel.packet[idx + n]);  // zip with the packet n positions later (second half of the group)
+ }
 }
 }
 }
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4s, 4>& kernel)
-{
- const int16x4x2_t zip16_1 = vzip_s16(kernel.packet[0], kernel.packet[1]);
- const int16x4x2_t zip16_2 = vzip_s16(kernel.packet[2], kernel.packet[3]);
- const uint32x2x2_t zip32_1 = vzip_u32(vreinterpret_u32_s16(zip16_1.val[0]), vreinterpret_u32_s16(zip16_2.val[0]));
- const uint32x2x2_t zip32_2 = vzip_u32(vreinterpret_u32_s16(zip16_1.val[1]), vreinterpret_u32_s16(zip16_2.val[1]));
+} // namespace detail
- kernel.packet[0] = vreinterpret_s16_u32(zip32_1.val[0]);
- kernel.packet[1] = vreinterpret_s16_u32(zip32_1.val[1]);
- kernel.packet[2] = vreinterpret_s16_u32(zip32_2.val[0]);
- kernel.packet[3] = vreinterpret_s16_u32(zip32_2.val[1]);
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet2f, 2>& kernel) {
+ detail::ptranspose_impl(kernel);
}
-
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8s, 4>& kernel)
-{
- const int16x8x2_t zip16_1 = vzipq_s16(kernel.packet[0], kernel.packet[1]);
- const int16x8x2_t zip16_2 = vzipq_s16(kernel.packet[2], kernel.packet[3]);
-
- const uint32x4x2_t zip32_1 = vzipq_u32(vreinterpretq_u32_s16(zip16_1.val[0]), vreinterpretq_u32_s16(zip16_2.val[0]));
- const uint32x4x2_t zip32_2 = vzipq_u32(vreinterpretq_u32_s16(zip16_1.val[1]), vreinterpretq_u32_s16(zip16_2.val[1]));
-
- kernel.packet[0] = vreinterpretq_s16_u32(zip32_1.val[0]);
- kernel.packet[1] = vreinterpretq_s16_u32(zip32_1.val[1]);
- kernel.packet[2] = vreinterpretq_s16_u32(zip32_2.val[0]);
- kernel.packet[3] = vreinterpretq_s16_u32(zip32_2.val[1]);
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4f, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
}
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16c, 4>& kernel)
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4c, 4>& kernel)
{
- const int8x16x2_t zip8_1 = vzipq_s8(kernel.packet[0], kernel.packet[1]);
- const int8x16x2_t zip8_2 = vzipq_s8(kernel.packet[2], kernel.packet[3]);
+ const int8x8_t a = vreinterpret_s8_s32(vset_lane_s32(kernel.packet[2], vdup_n_s32(kernel.packet[0]), 1));
+ const int8x8_t b = vreinterpret_s8_s32(vset_lane_s32(kernel.packet[3], vdup_n_s32(kernel.packet[1]), 1));
- const int16x8x2_t zip16_1 = vzipq_s16(vreinterpretq_s16_s8(zip8_1.val[0]), vreinterpretq_s16_s8(zip8_2.val[0]));
- const int16x8x2_t zip16_2 = vzipq_s16(vreinterpretq_s16_s8(zip8_1.val[1]), vreinterpretq_s16_s8(zip8_2.val[1]));
+ const int8x8x2_t zip8 = vzip_s8(a,b);
+ const int16x4x2_t zip16 = vzip_s16(vreinterpret_s16_s8(zip8.val[0]), vreinterpret_s16_s8(zip8.val[1]));
- kernel.packet[0] = vreinterpretq_s8_s16(zip16_1.val[0]);
- kernel.packet[1] = vreinterpretq_s8_s16(zip16_1.val[1]);
- kernel.packet[2] = vreinterpretq_s8_s16(zip16_2.val[0]);
- kernel.packet[3] = vreinterpretq_s8_s16(zip16_2.val[1]);
+ kernel.packet[0] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[0]), 0);
+ kernel.packet[1] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[0]), 1);
+ kernel.packet[2] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[1]), 0);
+ kernel.packet[3] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[1]), 1);
}
-
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16uc, 4>& kernel)
-{
- const uint8x16x2_t zip8_1 = vzipq_u8(kernel.packet[0], kernel.packet[1]);
- const uint8x16x2_t zip8_2 = vzipq_u8(kernel.packet[2], kernel.packet[3]);
-
- const uint16x8x2_t zip16_1 = vzipq_u16(vreinterpretq_u16_u8(zip8_1.val[0]), vreinterpretq_u16_u8(zip8_2.val[0]));
- const uint16x8x2_t zip16_2 = vzipq_u16(vreinterpretq_u16_u8(zip8_1.val[1]), vreinterpretq_u16_u8(zip8_2.val[1]));
-
- kernel.packet[0] = vreinterpretq_u8_u16(zip16_1.val[0]);
- kernel.packet[1] = vreinterpretq_u8_u16(zip16_1.val[1]);
- kernel.packet[2] = vreinterpretq_u8_u16(zip16_2.val[0]);
- kernel.packet[3] = vreinterpretq_u8_u16(zip16_2.val[1]);
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8c, 8>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8c, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16c, 16>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16c, 8>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16c, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
}
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8s, 8>& kernel)
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4uc, 4>& kernel)
{
- const int16x8x2_t zip16_1 = vzipq_s16(kernel.packet[0], kernel.packet[1]);
- const int16x8x2_t zip16_2 = vzipq_s16(kernel.packet[2], kernel.packet[3]);
- const int16x8x2_t zip16_3 = vzipq_s16(kernel.packet[4], kernel.packet[5]);
- const int16x8x2_t zip16_4 = vzipq_s16(kernel.packet[6], kernel.packet[7]);
+ const uint8x8_t a = vreinterpret_u8_u32(vset_lane_u32(kernel.packet[2], vdup_n_u32(kernel.packet[0]), 1));
+ const uint8x8_t b = vreinterpret_u8_u32(vset_lane_u32(kernel.packet[3], vdup_n_u32(kernel.packet[1]), 1));
- const uint32x4x2_t zip32_1 = vzipq_u32(vreinterpretq_u32_s16(zip16_1.val[0]), vreinterpretq_u32_s16(zip16_2.val[0]));
- const uint32x4x2_t zip32_2 = vzipq_u32(vreinterpretq_u32_s16(zip16_1.val[1]), vreinterpretq_u32_s16(zip16_2.val[1]));
- const uint32x4x2_t zip32_3 = vzipq_u32(vreinterpretq_u32_s16(zip16_3.val[0]), vreinterpretq_u32_s16(zip16_4.val[0]));
- const uint32x4x2_t zip32_4 = vzipq_u32(vreinterpretq_u32_s16(zip16_3.val[1]), vreinterpretq_u32_s16(zip16_4.val[1]));
+ const uint8x8x2_t zip8 = vzip_u8(a,b);
+ const uint16x4x2_t zip16 = vzip_u16(vreinterpret_u16_u8(zip8.val[0]), vreinterpret_u16_u8(zip8.val[1]));
- kernel.packet[0] = vreinterpretq_s16_u32(vcombine_u32(vget_low_u32(zip32_1.val[0]), vget_low_u32(zip32_3.val[0])));
- kernel.packet[1] = vreinterpretq_s16_u32(vcombine_u32(vget_high_u32(zip32_1.val[0]), vget_high_u32(zip32_3.val[0])));
- kernel.packet[2] = vreinterpretq_s16_u32(vcombine_u32(vget_low_u32(zip32_1.val[1]), vget_low_u32(zip32_3.val[1])));
- kernel.packet[3] = vreinterpretq_s16_u32(vcombine_u32(vget_high_u32(zip32_1.val[1]), vget_high_u32(zip32_3.val[1])));
- kernel.packet[4] = vreinterpretq_s16_u32(vcombine_u32(vget_low_u32(zip32_2.val[0]), vget_low_u32(zip32_4.val[0])));
- kernel.packet[5] = vreinterpretq_s16_u32(vcombine_u32(vget_high_u32(zip32_2.val[0]), vget_high_u32(zip32_4.val[0])));
- kernel.packet[6] = vreinterpretq_s16_u32(vcombine_u32(vget_low_u32(zip32_2.val[1]), vget_low_u32(zip32_4.val[1])));
- kernel.packet[7] = vreinterpretq_s16_u32(vcombine_u32(vget_high_u32(zip32_2.val[1]), vget_high_u32(zip32_4.val[1])));
+ kernel.packet[0] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[0]), 0);
+ kernel.packet[1] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[0]), 1);
+ kernel.packet[2] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[1]), 0);
+ kernel.packet[3] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[1]), 1);
}
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4us, 4>& kernel)
-{
- const uint16x4x2_t zip16_1 = vzip_u16(kernel.packet[0], kernel.packet[1]);
- const uint16x4x2_t zip16_2 = vzip_u16(kernel.packet[2], kernel.packet[3]);
-
- const uint32x2x2_t zip32_1 = vzip_u32(vreinterpret_u32_u16(zip16_1.val[0]), vreinterpret_u32_u16(zip16_2.val[0]));
- const uint32x2x2_t zip32_2 = vzip_u32(vreinterpret_u32_u16(zip16_1.val[1]), vreinterpret_u32_u16(zip16_2.val[1]));
-
- kernel.packet[0] = vreinterpret_u16_u32(zip32_1.val[0]);
- kernel.packet[1] = vreinterpret_u16_u32(zip32_1.val[1]);
- kernel.packet[2] = vreinterpret_u16_u32(zip32_2.val[0]);
- kernel.packet[3] = vreinterpret_u16_u32(zip32_2.val[1]);
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8uc, 8>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8uc, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16uc, 16>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16uc, 8>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16uc, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
}
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8us, 8>& kernel)
-{
- const uint16x8x2_t zip16_1 = vzipq_u16(kernel.packet[0], kernel.packet[1]);
- const uint16x8x2_t zip16_2 = vzipq_u16(kernel.packet[2], kernel.packet[3]);
- const uint16x8x2_t zip16_3 = vzipq_u16(kernel.packet[4], kernel.packet[5]);
- const uint16x8x2_t zip16_4 = vzipq_u16(kernel.packet[6], kernel.packet[7]);
-
- const uint32x4x2_t zip32_1 = vzipq_u32(vreinterpretq_u32_u16(zip16_1.val[0]), vreinterpretq_u32_u16(zip16_2.val[0]));
- const uint32x4x2_t zip32_2 = vzipq_u32(vreinterpretq_u32_u16(zip16_1.val[1]), vreinterpretq_u32_u16(zip16_2.val[1]));
- const uint32x4x2_t zip32_3 = vzipq_u32(vreinterpretq_u32_u16(zip16_3.val[0]), vreinterpretq_u32_u16(zip16_4.val[0]));
- const uint32x4x2_t zip32_4 = vzipq_u32(vreinterpretq_u32_u16(zip16_3.val[1]), vreinterpretq_u32_u16(zip16_4.val[1]));
- kernel.packet[0] = vreinterpretq_u16_u32(vcombine_u32(vget_low_u32(zip32_1.val[0]), vget_low_u32(zip32_3.val[0])));
- kernel.packet[1] = vreinterpretq_u16_u32(vcombine_u32(vget_high_u32(zip32_1.val[0]), vget_high_u32(zip32_3.val[0])));
- kernel.packet[2] = vreinterpretq_u16_u32(vcombine_u32(vget_low_u32(zip32_1.val[1]), vget_low_u32(zip32_3.val[1])));
- kernel.packet[3] = vreinterpretq_u16_u32(vcombine_u32(vget_high_u32(zip32_1.val[1]), vget_high_u32(zip32_3.val[1])));
- kernel.packet[4] = vreinterpretq_u16_u32(vcombine_u32(vget_low_u32(zip32_2.val[0]), vget_low_u32(zip32_4.val[0])));
- kernel.packet[5] = vreinterpretq_u16_u32(vcombine_u32(vget_high_u32(zip32_2.val[0]), vget_high_u32(zip32_4.val[0])));
- kernel.packet[6] = vreinterpretq_u16_u32(vcombine_u32(vget_low_u32(zip32_2.val[1]), vget_low_u32(zip32_4.val[1])));
- kernel.packet[7] = vreinterpretq_u16_u32(vcombine_u32(vget_high_u32(zip32_2.val[1]), vget_high_u32(zip32_4.val[1])));
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4s, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
}
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet2i, 2>& kernel)
-{
- const int32x2x2_t z = vzip_s32(kernel.packet[0], kernel.packet[1]);
- kernel.packet[0] = z.val[0];
- kernel.packet[1] = z.val[1];
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8s, 8>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8s, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
}
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4i, 4>& kernel)
-{
- const int32x4x2_t tmp1 = vzipq_s32(kernel.packet[0], kernel.packet[1]);
- const int32x4x2_t tmp2 = vzipq_s32(kernel.packet[2], kernel.packet[3]);
- kernel.packet[0] = vcombine_s32(vget_low_s32(tmp1.val[0]), vget_low_s32(tmp2.val[0]));
- kernel.packet[1] = vcombine_s32(vget_high_s32(tmp1.val[0]), vget_high_s32(tmp2.val[0]));
- kernel.packet[2] = vcombine_s32(vget_low_s32(tmp1.val[1]), vget_low_s32(tmp2.val[1]));
- kernel.packet[3] = vcombine_s32(vget_high_s32(tmp1.val[1]), vget_high_s32(tmp2.val[1]));
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4us, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
}
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet2ui, 2>& kernel)
-{
- const uint32x2x2_t z = vzip_u32(kernel.packet[0], kernel.packet[1]);
- kernel.packet[0] = z.val[0];
- kernel.packet[1] = z.val[1];
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8us, 8>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8us, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
}
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4ui, 4>& kernel)
-{
- const uint32x4x2_t tmp1 = vzipq_u32(kernel.packet[0], kernel.packet[1]);
- const uint32x4x2_t tmp2 = vzipq_u32(kernel.packet[2], kernel.packet[3]);
- kernel.packet[0] = vcombine_u32(vget_low_u32(tmp1.val[0]), vget_low_u32(tmp2.val[0]));
- kernel.packet[1] = vcombine_u32(vget_high_u32(tmp1.val[0]), vget_high_u32(tmp2.val[0]));
- kernel.packet[2] = vcombine_u32(vget_low_u32(tmp1.val[1]), vget_low_u32(tmp2.val[1]));
- kernel.packet[3] = vcombine_u32(vget_high_u32(tmp1.val[1]), vget_high_u32(tmp2.val[1]));
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet2i, 2>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4i, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet2ui, 2>& kernel) {
+ detail::zip_in_place(kernel.packet[0], kernel.packet[1]);  // calls the zip helper directly; detail::ptranspose_impl for a 2-packet block performs this same single zip
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4ui, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void
ptranspose(PacketBlock<Packet2l, 2>& kernel)
{
#if EIGEN_ARCH_ARM64
const int64x2_t tmp1 = vzip1q_s64(kernel.packet[0], kernel.packet[1]);
- const int64x2_t tmp2 = vzip2q_s64(kernel.packet[0], kernel.packet[1]);
-
+ kernel.packet[1] = vzip2q_s64(kernel.packet[0], kernel.packet[1]);
kernel.packet[0] = tmp1;
- kernel.packet[1] = tmp2;
#else
const int64x1_t tmp[2][2] = {
{ vget_low_s64(kernel.packet[0]), vget_high_s64(kernel.packet[0]) },
@@ -3135,10 +3048,8 @@ ptranspose(PacketBlock<Packet2ul, 2>& kernel)
{
#if EIGEN_ARCH_ARM64
const uint64x2_t tmp1 = vzip1q_u64(kernel.packet[0], kernel.packet[1]);
- const uint64x2_t tmp2 = vzip2q_u64(kernel.packet[0], kernel.packet[1]);
-
+ kernel.packet[1] = vzip2q_u64(kernel.packet[0], kernel.packet[1]);
kernel.packet[0] = tmp1;
- kernel.packet[1] = tmp2;
#else
const uint64x1_t tmp[2][2] = {
{ vget_low_u64(kernel.packet[0]), vget_high_u64(kernel.packet[0]) },
@@ -3468,6 +3379,15 @@ template<> struct unpacket_traits<Packet4bf>
};
};
+namespace detail {
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet4bf>(Packet4bf& p1, Packet4bf& p2) {  // in-place zip of two Packet4bf values
+ const uint16x4x2_t tmp = vzip_u16(p1, p2);  // Packet4bf is zipped with the u16 intrinsic; val[0]: interleaved low lanes, val[1]: interleaved high lanes
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+} // namespace detail
+
EIGEN_STRONG_INLINE Packet4bf F32ToBf16(const Packet4f& p)
{
// See the scalar implementation in BFloat16.h for a comprehensible explanation
@@ -3674,16 +3594,7 @@ template<> EIGEN_STRONG_INLINE Packet4bf preverse<Packet4bf>(const Packet4bf& a)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4bf, 4>& kernel)
{
- PacketBlock<Packet4us, 4> k;
- k.packet[0] = kernel.packet[0];
- k.packet[1] = kernel.packet[1];
- k.packet[2] = kernel.packet[2];
- k.packet[3] = kernel.packet[3];
- ptranspose(k);
- kernel.packet[0] = k.packet[0];
- kernel.packet[1] = k.packet[1];
- kernel.packet[2] = k.packet[2];
- kernel.packet[3] = k.packet[3];
+ detail::ptranspose_impl(kernel);
}
template<> EIGEN_STRONG_INLINE Packet4bf pabsdiff<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
diff --git a/test/packetmath.cpp b/test/packetmath.cpp
index c81ca63c4..121ec7283 100644
--- a/test/packetmath.cpp
+++ b/test/packetmath.cpp
@@ -546,22 +546,24 @@ void packetmath() {
}
}
- const int m_size = PacketSize < 4 ? 1 : 4;
- internal::PacketBlock<Packet, m_size> kernel2;
- for (int i = 0; i < m_size; ++i) {
- kernel2.packet[i] = internal::pload<Packet>(data1 + i * PacketSize);
- }
- ptranspose(kernel2);
- int data_counter = 0;
- for (int i = 0; i < PacketSize; ++i) {
- for (int j = 0; j < m_size; ++j) {
- data2[data_counter++] = data1[j*PacketSize + i];
+ // GeneralBlockPanelKernel also checks PacketBlock<Packet,(PacketSize%4)==0?4:PacketSize>;
+ if (PacketSize > 4 && PacketSize % 4 == 0) {
+ internal::PacketBlock<Packet, PacketSize%4==0?4:PacketSize> kernel2;
+ for (int i = 0; i < 4; ++i) {
+ kernel2.packet[i] = internal::pload<Packet>(data1 + i * PacketSize);
}
- }
- for (int i = 0; i < m_size; ++i) {
- internal::pstore(data3, kernel2.packet[i]);
- for (int j = 0; j < PacketSize; ++j) {
- VERIFY(test::isApproxAbs(data3[j], data2[i*PacketSize + j], refvalue) && "ptranspose");
+ ptranspose(kernel2);
+ int data_counter = 0;
+ for (int i = 0; i < PacketSize; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ data2[data_counter++] = data1[j*PacketSize + i];
+ }
+ }
+ for (int i = 0; i < 4; ++i) {
+ internal::pstore(data3, kernel2.packet[i]);
+ for (int j = 0; j < PacketSize; ++j) {
+ VERIFY(test::isApproxAbs(data3[j], data2[i*PacketSize + j], refvalue) && "ptranspose");
+ }
}
}