From dba753a986b527a17c8cc62474d0487aec7c2b36 Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Mon, 24 May 2021 21:34:35 -0700 Subject: Add missing NEON ptranspose implementations. Unified implementation using only `vzip`. --- Eigen/src/Core/arch/NEON/PacketMath.h | 507 ++++++++++++++-------------------- test/packetmath.cpp | 32 ++- 2 files changed, 226 insertions(+), 313 deletions(-) diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h index 2b48570d1..73a35c570 100644 --- a/Eigen/src/Core/arch/NEON/PacketMath.h +++ b/Eigen/src/Core/arch/NEON/PacketMath.h @@ -2774,352 +2774,265 @@ template<> EIGEN_STRONG_INLINE bool predux_any(const Packet4f& x) return vget_lane_u32(vpmax_u32(tmp, tmp), 0); } -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) -{ - const float32x2x2_t z = vzip_f32(kernel.packet[0], kernel.packet[1]); - kernel.packet[0] = z.val[0]; - kernel.packet[1] = z.val[1]; +// Helpers for ptranspose. +namespace detail { + +template +void zip_in_place(Packet& p1, Packet& p2); + +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet2f& p1, Packet2f& p2) { + const float32x2x2_t tmp = vzip_f32(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; } -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) -{ - const float32x4x2_t tmp1 = vzipq_f32(kernel.packet[0], kernel.packet[1]); - const float32x4x2_t tmp2 = vzipq_f32(kernel.packet[2], kernel.packet[3]); - kernel.packet[0] = vcombine_f32(vget_low_f32(tmp1.val[0]), vget_low_f32(tmp2.val[0])); - kernel.packet[1] = vcombine_f32(vget_high_f32(tmp1.val[0]), vget_high_f32(tmp2.val[0])); - kernel.packet[2] = vcombine_f32(vget_low_f32(tmp1.val[1]), vget_low_f32(tmp2.val[1])); - kernel.packet[3] = vcombine_f32(vget_high_f32(tmp1.val[1]), vget_high_f32(tmp2.val[1])); +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet4f& p1, Packet4f& p2) { + const float32x4x2_t tmp = vzipq_f32(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; } -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) -{ - const int8x8_t a = vreinterpret_s8_s32(vset_lane_s32(kernel.packet[2], vdup_n_s32(kernel.packet[0]), 1)); - const int8x8_t b = vreinterpret_s8_s32(vset_lane_s32(kernel.packet[3], vdup_n_s32(kernel.packet[1]), 1)); - const int8x8x2_t zip8 = vzip_s8(a,b); - const int16x4x2_t zip16 = vzip_s16(vreinterpret_s16_s8(zip8.val[0]), vreinterpret_s16_s8(zip8.val[1])); +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet8c& p1, Packet8c& p2) { + const int8x8x2_t tmp = vzip_s8(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} - kernel.packet[0] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[0]), 0); - kernel.packet[1] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[0]), 1); - kernel.packet[2] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[1]), 0); - kernel.packet[3] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[1]), 1); +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet16c& p1, Packet16c& p2) { + const int8x16x2_t tmp = vzipq_s8(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; } -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) -{ - int8x8x2_t zip8[4]; - uint16x4x2_t zip16[4]; - EIGEN_UNROLL_LOOP - for (int i = 0; i != 4; i++) - zip8[i] = vzip_s8(kernel.packet[i*2], kernel.packet[i*2+1]); +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet8uc& p1, Packet8uc& p2) { + const uint8x8x2_t tmp = vzip_u8(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} - EIGEN_UNROLL_LOOP - for (int i = 0; i != 2; i++) - { - EIGEN_UNROLL_LOOP - for (int j = 0; j != 2; j++) - zip16[i*2+j] = vzip_u16(vreinterpret_u16_s8(zip8[i*2].val[j]), vreinterpret_u16_s8(zip8[i*2+1].val[j])); - } +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet16uc& p1, Packet16uc& p2) { + const uint8x16x2_t tmp = vzipq_u8(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} - EIGEN_UNROLL_LOOP - for (int i = 0; i != 2; i++) - { - EIGEN_UNROLL_LOOP - for (int j = 0; j != 2; j++) - { - const uint32x2x2_t z = vzip_u32(vreinterpret_u32_u16(zip16[i].val[j]), vreinterpret_u32_u16(zip16[i+2].val[j])); - EIGEN_UNROLL_LOOP - for (int k = 0; k != 2; k++) - kernel.packet[i*4+j*2+k] = vreinterpret_s8_u32(z.val[k]); - } - } +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet2i& p1, Packet2i& p2) { + const int32x2x2_t tmp = vzip_s32(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; } -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) -{ - int8x16x2_t zip8[8]; - uint16x8x2_t zip16[8]; - uint32x4x2_t zip32[8]; - EIGEN_UNROLL_LOOP - for (int i = 0; i != 8; i++) - zip8[i] = vzipq_s8(kernel.packet[i*2], kernel.packet[i*2+1]); +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet4i& p1, Packet4i& p2) { + const int32x4x2_t tmp = vzipq_s32(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} - EIGEN_UNROLL_LOOP - for (int i = 0; i != 4; i++) - { - EIGEN_UNROLL_LOOP - for (int j = 0; j != 2; j++) - { - zip16[i*2+j] = vzipq_u16(vreinterpretq_u16_s8(zip8[i*2].val[j]), - vreinterpretq_u16_s8(zip8[i*2+1].val[j])); - } - } +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet2ui& p1, Packet2ui& p2) { + const uint32x2x2_t tmp = vzip_u32(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} - EIGEN_UNROLL_LOOP - for (int i = 0; i != 2; i++) - { - EIGEN_UNROLL_LOOP - for (int j = 0; j != 2; j++) - { - EIGEN_UNROLL_LOOP - for (int k = 0; k != 2; k++) - zip32[i*4+j*2+k] = vzipq_u32(vreinterpretq_u32_u16(zip16[i*4+j].val[k]), - vreinterpretq_u32_u16(zip16[i*4+j+2].val[k])); - } - } +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet4ui& p1, Packet4ui& p2) { + const uint32x4x2_t tmp = vzipq_u32(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} - EIGEN_UNROLL_LOOP - for (int i = 0; i != 4; i++) - { - EIGEN_UNROLL_LOOP - for (int j = 0; j != 2; j++) - { - kernel.packet[i*4+j*2] = vreinterpretq_s8_u32(vcombine_u32(vget_low_u32(zip32[i].val[j]), - vget_low_u32(zip32[i+4].val[j]))); - kernel.packet[i*4+j*2+1] = vreinterpretq_s8_u32(vcombine_u32(vget_high_u32(zip32[i].val[j]), - vget_high_u32(zip32[i+4].val[j]))); - } - } +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet4s& p1, Packet4s& p2) { + const int16x4x2_t tmp = vzip_s16(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; } -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) -{ - const uint8x8_t a = vreinterpret_u8_u32(vset_lane_u32(kernel.packet[2], vdup_n_u32(kernel.packet[0]), 1)); - const uint8x8_t b = vreinterpret_u8_u32(vset_lane_u32(kernel.packet[3], vdup_n_u32(kernel.packet[1]), 1)); - const uint8x8x2_t zip8 = vzip_u8(a,b); - const uint16x4x2_t zip16 = vzip_u16(vreinterpret_u16_u8(zip8.val[0]), vreinterpret_u16_u8(zip8.val[1])); +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet8s& p1, Packet8s& p2) { + const int16x8x2_t tmp = vzipq_s16(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} - kernel.packet[0] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[0]), 0); - kernel.packet[1] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[0]), 1); - kernel.packet[2] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[1]), 0); - kernel.packet[3] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[1]), 1); +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet4us& p1, Packet4us& p2) { + const uint16x4x2_t tmp = vzip_u16(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; } -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) -{ - uint8x8x2_t zip8[4]; - uint16x4x2_t zip16[4]; - EIGEN_UNROLL_LOOP - for (int i = 0; i != 4; i++) - zip8[i] = vzip_u8(kernel.packet[i*2], kernel.packet[i*2+1]); +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet8us& p1, Packet8us& p2) { + const uint16x8x2_t tmp = vzipq_u16(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} - EIGEN_UNROLL_LOOP - for (int i = 0; i != 2; i++) - { - EIGEN_UNROLL_LOOP - for (int j = 0; j != 2; j++) - zip16[i*2+j] = vzip_u16(vreinterpret_u16_u8(zip8[i*2].val[j]), vreinterpret_u16_u8(zip8[i*2+1].val[j])); - } +template +EIGEN_ALWAYS_INLINE void ptranspose_impl(PacketBlock& kernel) { + zip_in_place(kernel.packet[0], kernel.packet[1]); +} - EIGEN_UNROLL_LOOP - for (int i = 0; i != 2; i++) - { - EIGEN_UNROLL_LOOP - for (int j = 0; j != 2; j++) - { - const uint32x2x2_t z = vzip_u32(vreinterpret_u32_u16(zip16[i].val[j]), vreinterpret_u32_u16(zip16[i+2].val[j])); - EIGEN_UNROLL_LOOP - for (int k = 0; k != 2; k++) - kernel.packet[i*4+j*2+k] = vreinterpret_u8_u32(z.val[k]); - } - } +template +EIGEN_ALWAYS_INLINE void ptranspose_impl(PacketBlock& kernel) { + zip_in_place(kernel.packet[0], kernel.packet[2]); + zip_in_place(kernel.packet[1], kernel.packet[3]); + zip_in_place(kernel.packet[0], kernel.packet[1]); + zip_in_place(kernel.packet[2], kernel.packet[3]); } -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) -{ - uint8x16x2_t zip8[8]; - uint16x8x2_t zip16[8]; - uint32x4x2_t zip32[8]; - EIGEN_UNROLL_LOOP - for (int i = 0; i != 8; i++) - zip8[i] = vzipq_u8(kernel.packet[i*2], kernel.packet[i*2+1]); +template +EIGEN_ALWAYS_INLINE void ptranspose_impl(PacketBlock& kernel) { + zip_in_place(kernel.packet[0], kernel.packet[4]); + zip_in_place(kernel.packet[1], kernel.packet[5]); + zip_in_place(kernel.packet[2], kernel.packet[6]); + zip_in_place(kernel.packet[3], kernel.packet[7]); - EIGEN_UNROLL_LOOP - for (int i = 0; i != 4; i++) - { - EIGEN_UNROLL_LOOP - for (int j = 0; j != 2; j++) - zip16[i*2+j] = vzipq_u16(vreinterpretq_u16_u8(zip8[i*2].val[j]), - vreinterpretq_u16_u8(zip8[i*2+1].val[j])); - } + zip_in_place(kernel.packet[0], kernel.packet[2]); + zip_in_place(kernel.packet[1], kernel.packet[3]); + zip_in_place(kernel.packet[4], kernel.packet[6]); + zip_in_place(kernel.packet[5], kernel.packet[7]); + + zip_in_place(kernel.packet[0], kernel.packet[1]); + zip_in_place(kernel.packet[2], kernel.packet[3]); + zip_in_place(kernel.packet[4], kernel.packet[5]); + zip_in_place(kernel.packet[6], kernel.packet[7]); +} +template +EIGEN_ALWAYS_INLINE void ptranspose_impl(PacketBlock& kernel) { EIGEN_UNROLL_LOOP - for (int i = 0; i != 2; i++) - { + for (int i=0; i<4; ++i) { + const int m = (1 << i); EIGEN_UNROLL_LOOP - for (int j = 0; j != 2; j++) - { + for (int j=0; j& kernel) -{ - const int16x4x2_t zip16_1 = vzip_s16(kernel.packet[0], kernel.packet[1]); - const int16x4x2_t zip16_2 = vzip_s16(kernel.packet[2], kernel.packet[3]); - const uint32x2x2_t zip32_1 = vzip_u32(vreinterpret_u32_s16(zip16_1.val[0]), vreinterpret_u32_s16(zip16_2.val[0])); - const uint32x2x2_t zip32_2 = vzip_u32(vreinterpret_u32_s16(zip16_1.val[1]), vreinterpret_u32_s16(zip16_2.val[1])); +} // namespace detail - kernel.packet[0] = vreinterpret_s16_u32(zip32_1.val[0]); - kernel.packet[1] = vreinterpret_s16_u32(zip32_1.val[1]); - kernel.packet[2] = vreinterpret_s16_u32(zip32_2.val[0]); - kernel.packet[3] = vreinterpret_s16_u32(zip32_2.val[1]); +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); } - -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) -{ - const int16x8x2_t zip16_1 = vzipq_s16(kernel.packet[0], kernel.packet[1]); - const int16x8x2_t zip16_2 = vzipq_s16(kernel.packet[2], kernel.packet[3]); - - const uint32x4x2_t zip32_1 = vzipq_u32(vreinterpretq_u32_s16(zip16_1.val[0]), vreinterpretq_u32_s16(zip16_2.val[0])); - const uint32x4x2_t zip32_2 = vzipq_u32(vreinterpretq_u32_s16(zip16_1.val[1]), vreinterpretq_u32_s16(zip16_2.val[1])); - - kernel.packet[0] = vreinterpretq_s16_u32(zip32_1.val[0]); - kernel.packet[1] = vreinterpretq_s16_u32(zip32_1.val[1]); - kernel.packet[2] = vreinterpretq_s16_u32(zip32_2.val[0]); - kernel.packet[3] = vreinterpretq_s16_u32(zip32_2.val[1]); +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); } -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { - const int8x16x2_t zip8_1 = vzipq_s8(kernel.packet[0], kernel.packet[1]); - const int8x16x2_t zip8_2 = vzipq_s8(kernel.packet[2], kernel.packet[3]); + const int8x8_t a = vreinterpret_s8_s32(vset_lane_s32(kernel.packet[2], vdup_n_s32(kernel.packet[0]), 1)); + const int8x8_t b = vreinterpret_s8_s32(vset_lane_s32(kernel.packet[3], vdup_n_s32(kernel.packet[1]), 1)); - const int16x8x2_t zip16_1 = vzipq_s16(vreinterpretq_s16_s8(zip8_1.val[0]), vreinterpretq_s16_s8(zip8_2.val[0])); - const int16x8x2_t zip16_2 = vzipq_s16(vreinterpretq_s16_s8(zip8_1.val[1]), vreinterpretq_s16_s8(zip8_2.val[1])); + const int8x8x2_t zip8 = vzip_s8(a,b); + const int16x4x2_t zip16 = vzip_s16(vreinterpret_s16_s8(zip8.val[0]), vreinterpret_s16_s8(zip8.val[1])); - kernel.packet[0] = vreinterpretq_s8_s16(zip16_1.val[0]); - kernel.packet[1] = vreinterpretq_s8_s16(zip16_1.val[1]); - kernel.packet[2] = vreinterpretq_s8_s16(zip16_2.val[0]); - kernel.packet[3] = vreinterpretq_s8_s16(zip16_2.val[1]); + kernel.packet[0] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[0]), 0); + kernel.packet[1] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[0]), 1); + kernel.packet[2] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[1]), 0); + kernel.packet[3] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[1]), 1); } - -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) -{ - const uint8x16x2_t zip8_1 = vzipq_u8(kernel.packet[0], kernel.packet[1]); - const uint8x16x2_t zip8_2 = vzipq_u8(kernel.packet[2], kernel.packet[3]); - - const uint16x8x2_t zip16_1 = vzipq_u16(vreinterpretq_u16_u8(zip8_1.val[0]), vreinterpretq_u16_u8(zip8_2.val[0])); - const uint16x8x2_t zip16_2 = vzipq_u16(vreinterpretq_u16_u8(zip8_1.val[1]), vreinterpretq_u16_u8(zip8_2.val[1])); - - kernel.packet[0] = vreinterpretq_u8_u16(zip16_1.val[0]); - kernel.packet[1] = vreinterpretq_u8_u16(zip16_1.val[1]); - kernel.packet[2] = vreinterpretq_u8_u16(zip16_2.val[0]); - kernel.packet[3] = vreinterpretq_u8_u16(zip16_2.val[1]); +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); } -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { - const int16x8x2_t zip16_1 = vzipq_s16(kernel.packet[0], kernel.packet[1]); - const int16x8x2_t zip16_2 = vzipq_s16(kernel.packet[2], kernel.packet[3]); - const int16x8x2_t zip16_3 = vzipq_s16(kernel.packet[4], kernel.packet[5]); - const int16x8x2_t zip16_4 = vzipq_s16(kernel.packet[6], kernel.packet[7]); + const uint8x8_t a = vreinterpret_u8_u32(vset_lane_u32(kernel.packet[2], vdup_n_u32(kernel.packet[0]), 1)); + const uint8x8_t b = vreinterpret_u8_u32(vset_lane_u32(kernel.packet[3], vdup_n_u32(kernel.packet[1]), 1)); - const uint32x4x2_t zip32_1 = vzipq_u32(vreinterpretq_u32_s16(zip16_1.val[0]), vreinterpretq_u32_s16(zip16_2.val[0])); - const uint32x4x2_t zip32_2 = vzipq_u32(vreinterpretq_u32_s16(zip16_1.val[1]), vreinterpretq_u32_s16(zip16_2.val[1])); - const uint32x4x2_t zip32_3 = vzipq_u32(vreinterpretq_u32_s16(zip16_3.val[0]), vreinterpretq_u32_s16(zip16_4.val[0])); - const uint32x4x2_t zip32_4 = vzipq_u32(vreinterpretq_u32_s16(zip16_3.val[1]), vreinterpretq_u32_s16(zip16_4.val[1])); + const uint8x8x2_t zip8 = vzip_u8(a,b); + const uint16x4x2_t zip16 = vzip_u16(vreinterpret_u16_u8(zip8.val[0]), vreinterpret_u16_u8(zip8.val[1])); - kernel.packet[0] = vreinterpretq_s16_u32(vcombine_u32(vget_low_u32(zip32_1.val[0]), vget_low_u32(zip32_3.val[0]))); - kernel.packet[1] = vreinterpretq_s16_u32(vcombine_u32(vget_high_u32(zip32_1.val[0]), vget_high_u32(zip32_3.val[0]))); - kernel.packet[2] = vreinterpretq_s16_u32(vcombine_u32(vget_low_u32(zip32_1.val[1]), vget_low_u32(zip32_3.val[1]))); - kernel.packet[3] = vreinterpretq_s16_u32(vcombine_u32(vget_high_u32(zip32_1.val[1]), vget_high_u32(zip32_3.val[1]))); - kernel.packet[4] = vreinterpretq_s16_u32(vcombine_u32(vget_low_u32(zip32_2.val[0]), vget_low_u32(zip32_4.val[0]))); - kernel.packet[5] = vreinterpretq_s16_u32(vcombine_u32(vget_high_u32(zip32_2.val[0]), vget_high_u32(zip32_4.val[0]))); - kernel.packet[6] = vreinterpretq_s16_u32(vcombine_u32(vget_low_u32(zip32_2.val[1]), vget_low_u32(zip32_4.val[1]))); - kernel.packet[7] = vreinterpretq_s16_u32(vcombine_u32(vget_high_u32(zip32_2.val[1]), vget_high_u32(zip32_4.val[1]))); + kernel.packet[0] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[0]), 0); + kernel.packet[1] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[0]), 1); + kernel.packet[2] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[1]), 0); + kernel.packet[3] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[1]), 1); } -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) -{ - const uint16x4x2_t zip16_1 = vzip_u16(kernel.packet[0], kernel.packet[1]); - const uint16x4x2_t zip16_2 = vzip_u16(kernel.packet[2], kernel.packet[3]); - - const uint32x2x2_t zip32_1 = vzip_u32(vreinterpret_u32_u16(zip16_1.val[0]), vreinterpret_u32_u16(zip16_2.val[0])); - const uint32x2x2_t zip32_2 = vzip_u32(vreinterpret_u32_u16(zip16_1.val[1]), vreinterpret_u32_u16(zip16_2.val[1])); - - kernel.packet[0] = vreinterpret_u16_u32(zip32_1.val[0]); - kernel.packet[1] = vreinterpret_u16_u32(zip32_1.val[1]); - kernel.packet[2] = vreinterpret_u16_u32(zip32_2.val[0]); - kernel.packet[3] = vreinterpret_u16_u32(zip32_2.val[1]); +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); } -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) -{ - const uint16x8x2_t zip16_1 = vzipq_u16(kernel.packet[0], kernel.packet[1]); - const uint16x8x2_t zip16_2 = vzipq_u16(kernel.packet[2], kernel.packet[3]); - const uint16x8x2_t zip16_3 = vzipq_u16(kernel.packet[4], kernel.packet[5]); - const uint16x8x2_t zip16_4 = vzipq_u16(kernel.packet[6], kernel.packet[7]); - - const uint32x4x2_t zip32_1 = vzipq_u32(vreinterpretq_u32_u16(zip16_1.val[0]), vreinterpretq_u32_u16(zip16_2.val[0])); - const uint32x4x2_t zip32_2 = vzipq_u32(vreinterpretq_u32_u16(zip16_1.val[1]), vreinterpretq_u32_u16(zip16_2.val[1])); - const uint32x4x2_t zip32_3 = vzipq_u32(vreinterpretq_u32_u16(zip16_3.val[0]), vreinterpretq_u32_u16(zip16_4.val[0])); - const uint32x4x2_t zip32_4 = vzipq_u32(vreinterpretq_u32_u16(zip16_3.val[1]), vreinterpretq_u32_u16(zip16_4.val[1])); - kernel.packet[0] = vreinterpretq_u16_u32(vcombine_u32(vget_low_u32(zip32_1.val[0]), vget_low_u32(zip32_3.val[0]))); - kernel.packet[1] = vreinterpretq_u16_u32(vcombine_u32(vget_high_u32(zip32_1.val[0]), vget_high_u32(zip32_3.val[0]))); - kernel.packet[2] = vreinterpretq_u16_u32(vcombine_u32(vget_low_u32(zip32_1.val[1]), vget_low_u32(zip32_3.val[1]))); - kernel.packet[3] = vreinterpretq_u16_u32(vcombine_u32(vget_high_u32(zip32_1.val[1]), vget_high_u32(zip32_3.val[1]))); - kernel.packet[4] = vreinterpretq_u16_u32(vcombine_u32(vget_low_u32(zip32_2.val[0]), vget_low_u32(zip32_4.val[0]))); - kernel.packet[5] = vreinterpretq_u16_u32(vcombine_u32(vget_high_u32(zip32_2.val[0]), vget_high_u32(zip32_4.val[0]))); - kernel.packet[6] = vreinterpretq_u16_u32(vcombine_u32(vget_low_u32(zip32_2.val[1]), vget_low_u32(zip32_4.val[1]))); - kernel.packet[7] = vreinterpretq_u16_u32(vcombine_u32(vget_high_u32(zip32_2.val[1]), vget_high_u32(zip32_4.val[1]))); +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); } -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) -{ - const int32x2x2_t z = vzip_s32(kernel.packet[0], kernel.packet[1]); - kernel.packet[0] = z.val[0]; - kernel.packet[1] = z.val[1]; +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); } -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) -{ - const int32x4x2_t tmp1 = vzipq_s32(kernel.packet[0], kernel.packet[1]); - const int32x4x2_t tmp2 = vzipq_s32(kernel.packet[2], kernel.packet[3]); - kernel.packet[0] = vcombine_s32(vget_low_s32(tmp1.val[0]), vget_low_s32(tmp2.val[0])); - kernel.packet[1] = vcombine_s32(vget_high_s32(tmp1.val[0]), vget_high_s32(tmp2.val[0])); - kernel.packet[2] = vcombine_s32(vget_low_s32(tmp1.val[1]), vget_low_s32(tmp2.val[1])); - kernel.packet[3] = vcombine_s32(vget_high_s32(tmp1.val[1]), vget_high_s32(tmp2.val[1])); +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); } -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) -{ - const uint32x2x2_t z = vzip_u32(kernel.packet[0], kernel.packet[1]); - kernel.packet[0] = z.val[0]; - kernel.packet[1] = z.val[1]; +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); } -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) -{ - const uint32x4x2_t tmp1 = vzipq_u32(kernel.packet[0], kernel.packet[1]); - const uint32x4x2_t tmp2 = vzipq_u32(kernel.packet[2], kernel.packet[3]); - kernel.packet[0] = vcombine_u32(vget_low_u32(tmp1.val[0]), vget_low_u32(tmp2.val[0])); - kernel.packet[1] = vcombine_u32(vget_high_u32(tmp1.val[0]), vget_high_u32(tmp2.val[0])); - kernel.packet[2] = vcombine_u32(vget_low_u32(tmp1.val[1]), vget_low_u32(tmp2.val[1])); - kernel.packet[3] = vcombine_u32(vget_high_u32(tmp1.val[1]), vget_high_u32(tmp2.val[1])); +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); } +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::zip_in_place(kernel.packet[0], kernel.packet[1]); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { #if EIGEN_ARCH_ARM64 const int64x2_t tmp1 = vzip1q_s64(kernel.packet[0], kernel.packet[1]); - const int64x2_t tmp2 = vzip2q_s64(kernel.packet[0], kernel.packet[1]); - + kernel.packet[1] = vzip2q_s64(kernel.packet[0], kernel.packet[1]); kernel.packet[0] = tmp1; - kernel.packet[1] = tmp2; #else const int64x1_t tmp[2][2] = { { vget_low_s64(kernel.packet[0]), vget_high_s64(kernel.packet[0]) }, @@ -3135,10 +3048,8 @@ ptranspose(PacketBlock& kernel) { #if EIGEN_ARCH_ARM64 const uint64x2_t tmp1 = vzip1q_u64(kernel.packet[0], kernel.packet[1]); - const uint64x2_t tmp2 = vzip2q_u64(kernel.packet[0], kernel.packet[1]); - + kernel.packet[1] = vzip2q_u64(kernel.packet[0], kernel.packet[1]); kernel.packet[0] = tmp1; - kernel.packet[1] = tmp2; #else const uint64x1_t tmp[2][2] = { { vget_low_u64(kernel.packet[0]), vget_high_u64(kernel.packet[0]) }, @@ -3468,6 +3379,15 @@ template<> struct unpacket_traits }; }; +namespace detail { +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet4bf& p1, Packet4bf& p2) { + const uint16x4x2_t tmp = vzip_u16(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} +} // namespace detail + EIGEN_STRONG_INLINE Packet4bf F32ToBf16(const Packet4f& p) { // See the scalar implemention in BFloat16.h for a comprehensible explanation @@ -3674,16 +3594,7 @@ template<> EIGEN_STRONG_INLINE Packet4bf preverse(const Packet4bf& a) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { - PacketBlock k; - k.packet[0] = kernel.packet[0]; - k.packet[1] = kernel.packet[1]; - k.packet[2] = kernel.packet[2]; - k.packet[3] = kernel.packet[3]; - ptranspose(k); - kernel.packet[0] = k.packet[0]; - kernel.packet[1] = k.packet[1]; - kernel.packet[2] = k.packet[2]; - kernel.packet[3] = k.packet[3]; + detail::ptranspose_impl(kernel); } template<> EIGEN_STRONG_INLINE Packet4bf pabsdiff(const Packet4bf& a, const Packet4bf& b) diff --git a/test/packetmath.cpp b/test/packetmath.cpp index c81ca63c4..121ec7283 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -546,22 +546,24 @@ void packetmath() { } } - const int m_size = PacketSize < 4 ? 1 : 4; - internal::PacketBlock kernel2; - for (int i = 0; i < m_size; ++i) { - kernel2.packet[i] = internal::pload(data1 + i * PacketSize); - } - ptranspose(kernel2); - int data_counter = 0; - for (int i = 0; i < PacketSize; ++i) { - for (int j = 0; j < m_size; ++j) { - data2[data_counter++] = data1[j*PacketSize + i]; + // GeneralBlockPanelKernel also checks PacketBlock; + if (PacketSize > 4 && PacketSize % 4 == 0) { + internal::PacketBlock kernel2; + for (int i = 0; i < 4; ++i) { + kernel2.packet[i] = internal::pload(data1 + i * PacketSize); } - } - for (int i = 0; i < m_size; ++i) { - internal::pstore(data3, kernel2.packet[i]); - for (int j = 0; j < PacketSize; ++j) { - VERIFY(test::isApproxAbs(data3[j], data2[i*PacketSize + j], refvalue) && "ptranspose"); + ptranspose(kernel2); + int data_counter = 0; + for (int i = 0; i < PacketSize; ++i) { + for (int j = 0; j < 4; ++j) { + data2[data_counter++] = data1[j*PacketSize + i]; + } + } + for (int i = 0; i < 4; ++i) { + internal::pstore(data3, kernel2.packet[i]); + for (int j = 0; j < PacketSize; ++j) { + VERIFY(test::isApproxAbs(data3[j], data2[i*PacketSize + j], refvalue) && "ptranspose"); + } } } -- cgit v1.2.3