// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2020, Arm Limited and Contributors
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_PACKET_MATH_SVE_H
#define EIGEN_PACKET_MATH_SVE_H

namespace Eigen {
namespace internal {

#ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD
#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8
#endif

#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#endif

#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32

template <typename Scalar, int SVEVectorLength>
struct sve_packet_size_selector {
  enum { size = SVEVectorLength / (sizeof(Scalar) * CHAR_BIT) };
};

/********************************* int32 **************************************/
typedef svint32_t PacketXi __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL)));

template <>
struct packet_traits<numext::int32_t> : default_packet_traits {
  typedef PacketXi type;
  typedef PacketXi half;  // Half not implemented yet
  enum {
    Vectorizable = 1,
    AlignedOnScalar = 1,
    size = sve_packet_size_selector<numext::int32_t, EIGEN_ARM64_SVE_VL>::size,
    HasHalfPacket = 0,

    HasAdd = 1,
    HasSub = 1,
    HasShift = 1,
    HasMul = 1,
    HasNegate = 1,
    HasAbs = 1,
    HasArg = 0,
    HasAbs2 = 1,
    HasMin = 1,
    HasMax = 1,
    HasConj = 1,
    HasSetLinear = 0,
    HasBlend = 0,
    HasReduxp = 0  // Not implemented in SVE
  };
};

template <>
struct unpacket_traits<PacketXi> {
  typedef numext::int32_t type;
  typedef PacketXi half;  // Half not yet implemented
  enum {
    size = sve_packet_size_selector<numext::int32_t, EIGEN_ARM64_SVE_VL>::size,
    alignment = Aligned64,
    vectorizable = true,
    masked_load_available = false,
    masked_store_available = false
  };
};

template <>
EIGEN_STRONG_INLINE void prefetch<numext::int32_t>(const numext::int32_t* addr)
{
  svprfw(svptrue_b32(), addr, SV_PLDL1KEEP);
}

template <>
EIGEN_STRONG_INLINE PacketXi pset1<PacketXi>(const numext::int32_t& from)
{
  return svdup_n_s32(from);
}

template <>
EIGEN_STRONG_INLINE PacketXi plset<PacketXi>(const numext::int32_t& a)
{
  numext::int32_t c[packet_traits<numext::int32_t>::size];
  for (int i = 0; i < packet_traits<numext::int32_t>::size; i++) c[i] = i;
  return svadd_s32_z(svptrue_b32(), pset1<PacketXi>(a), svld1_s32(svptrue_b32(), c));
}

template <>
EIGEN_STRONG_INLINE PacketXi padd<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svadd_s32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi psub<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svsub_s32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi pnegate<PacketXi>(const PacketXi& a)
{
  return svneg_s32_z(svptrue_b32(), a);
}

template <>
EIGEN_STRONG_INLINE PacketXi pconj<PacketXi>(const PacketXi& a)
{
  return a;
}

template <>
EIGEN_STRONG_INLINE PacketXi pmul<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svmul_s32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi pdiv<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svdiv_s32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi pmadd(const PacketXi& a, const PacketXi& b, const PacketXi& c)
{
  return svmla_s32_z(svptrue_b32(), c, a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi pmin<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svmin_s32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi pmax<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svmax_s32_z(svptrue_b32(), a, b);
}

// Integer comparisons in SVE return a predicate (svbool). Use svdup to set
// all bits of the active lanes (0xffffffffu) and clear the inactive lanes.
template <>
EIGEN_STRONG_INLINE PacketXi pcmp_le<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svdup_n_s32_z(svcmple_s32(svptrue_b32(), a, b), 0xffffffffu);
}

template <>
EIGEN_STRONG_INLINE PacketXi pcmp_lt<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svdup_n_s32_z(svcmplt_s32(svptrue_b32(), a, b), 0xffffffffu);
}

template <>
EIGEN_STRONG_INLINE PacketXi pcmp_eq<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svdup_n_s32_z(svcmpeq_s32(svptrue_b32(), a, b), 0xffffffffu);
}

template <>
EIGEN_STRONG_INLINE PacketXi ptrue<PacketXi>(const PacketXi& /*a*/)
{
  return svdup_n_s32_z(svptrue_b32(), 0xffffffffu);
}

template <>
EIGEN_STRONG_INLINE PacketXi pzero<PacketXi>(const PacketXi& /*a*/)
{
  return svdup_n_s32_z(svptrue_b32(), 0);
}
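// Illustrative sketch (not part of this file's API; the helper name is
// hypothetical): because the pcmp_* functions above materialize their
// predicate as all-ones/all-zeros lanes, a branchless select can be composed
// from the bitwise ops defined below:
//
//   PacketXi select_example(const PacketXi& mask, const PacketXi& a, const PacketXi& b)
//   {
//     // lanes where mask is all-ones take a, the remaining lanes take b
//     return por(pand(mask, a), pandnot(b, mask));
//   }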
template <>
EIGEN_STRONG_INLINE PacketXi pand<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svand_s32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi por<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svorr_s32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi pxor<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return sveor_s32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi pandnot<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svbic_s32_z(svptrue_b32(), a, b);
}

template <int N>
EIGEN_STRONG_INLINE PacketXi parithmetic_shift_right(PacketXi a)
{
  return svasrd_n_s32_z(svptrue_b32(), a, N);
}

template <int N>
EIGEN_STRONG_INLINE PacketXi plogical_shift_right(PacketXi a)
{
  return svreinterpret_s32_u32(svlsr_u32_z(svptrue_b32(), svreinterpret_u32_s32(a), svdup_n_u32_z(svptrue_b32(), N)));
}

template <int N>
EIGEN_STRONG_INLINE PacketXi plogical_shift_left(PacketXi a)
{
  return svlsl_s32_z(svptrue_b32(), a, svdup_n_u32_z(svptrue_b32(), N));
}

template <>
EIGEN_STRONG_INLINE PacketXi pload<PacketXi>(const numext::int32_t* from)
{
  EIGEN_DEBUG_ALIGNED_LOAD return svld1_s32(svptrue_b32(), from);
}

template <>
EIGEN_STRONG_INLINE PacketXi ploadu<PacketXi>(const numext::int32_t* from)
{
  EIGEN_DEBUG_UNALIGNED_LOAD return svld1_s32(svptrue_b32(), from);
}

template <>
EIGEN_STRONG_INLINE PacketXi ploaddup<PacketXi>(const numext::int32_t* from)
{
  svuint32_t indices = svindex_u32(0, 1);  // index {base=0, base+step=1, base+step*2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a1, a1, a2, a2, ...}
  return svld1_gather_u32index_s32(svptrue_b32(), from, indices);
}

template <>
EIGEN_STRONG_INLINE PacketXi ploadquad<PacketXi>(const numext::int32_t* from)
{
  svuint32_t indices = svindex_u32(0, 1);  // index {base=0, base+step=1, base+step*2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a1, a1, a2, a2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...}
  return svld1_gather_u32index_s32(svptrue_b32(), from, indices);
}

template <>
EIGEN_STRONG_INLINE void pstore<numext::int32_t>(numext::int32_t* to, const PacketXi& from)
{
  EIGEN_DEBUG_ALIGNED_STORE svst1_s32(svptrue_b32(), to, from);
}

template <>
EIGEN_STRONG_INLINE void pstoreu<numext::int32_t>(numext::int32_t* to, const PacketXi& from)
{
  EIGEN_DEBUG_UNALIGNED_STORE svst1_s32(svptrue_b32(), to, from);
}

template <>
EIGEN_DEVICE_FUNC inline PacketXi pgather<numext::int32_t, PacketXi>(const numext::int32_t* from, Index stride)
{
  // Index format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
  svint32_t indices = svindex_s32(0, stride);
  return svld1_gather_s32index_s32(svptrue_b32(), from, indices);
}

template <>
EIGEN_DEVICE_FUNC inline void pscatter<numext::int32_t, PacketXi>(numext::int32_t* to, const PacketXi& from, Index stride)
{
  // Index format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
  svint32_t indices = svindex_s32(0, stride);
  svst1_scatter_s32index_s32(svptrue_b32(), to, indices, from);
}
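// Usage sketch (the values are hypothetical): for stride == 3, pgather loads
// from[0], from[3], from[6], ... into consecutive packet lanes, and pscatter
// performs the inverse strided store. For example:
//
//   numext::int32_t data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, /* ... */};
//   PacketXi p = pgather<numext::int32_t, PacketXi>(data, 3);
//   // p now holds {0, 3, 6, ...}, one element per SVE lane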
template <>
EIGEN_STRONG_INLINE numext::int32_t pfirst<PacketXi>(const PacketXi& a)
{
  // svlasta returns the first element if all predicate bits are 0
  return svlasta_s32(svpfalse_b(), a);
}

template <>
EIGEN_STRONG_INLINE PacketXi preverse(const PacketXi& a)
{
  return svrev_s32(a);
}

template <>
EIGEN_STRONG_INLINE PacketXi pabs(const PacketXi& a)
{
  return svabs_s32_z(svptrue_b32(), a);
}

template <>
EIGEN_STRONG_INLINE numext::int32_t predux<PacketXi>(const PacketXi& a)
{
  return static_cast<numext::int32_t>(svaddv_s32(svptrue_b32(), a));
}

template <>
EIGEN_STRONG_INLINE numext::int32_t predux_mul<PacketXi>(const PacketXi& a)
{
  EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0),
                      EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);

  // Multiply the vector by its reverse
  svint32_t prod = svmul_s32_z(svptrue_b32(), a, svrev_s32(a));
  svint32_t half_prod;

  // Extract the high half of the vector. Depending on the VL, more reduction
  // steps are needed.
  if (EIGEN_ARM64_SVE_VL >= 2048) {
    half_prod = svtbl_s32(prod, svindex_u32(32, 1));
    prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 1024) {
    half_prod = svtbl_s32(prod, svindex_u32(16, 1));
    prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 512) {
    half_prod = svtbl_s32(prod, svindex_u32(8, 1));
    prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 256) {
    half_prod = svtbl_s32(prod, svindex_u32(4, 1));
    prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
  }
  // Last reduction
  half_prod = svtbl_s32(prod, svindex_u32(2, 1));
  prod = svmul_s32_z(svptrue_b32(), prod, half_prod);

  // The reduction is done to the first element.
  return pfirst<PacketXi>(prod);
}

template <>
EIGEN_STRONG_INLINE numext::int32_t predux_min<PacketXi>(const PacketXi& a)
{
  return svminv_s32(svptrue_b32(), a);
}

template <>
EIGEN_STRONG_INLINE numext::int32_t predux_max<PacketXi>(const PacketXi& a)
{
  return svmaxv_s32(svptrue_b32(), a);
}

template <int N>
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<PacketXi, N>& kernel)
{
  int buffer[packet_traits<numext::int32_t>::size * N] = {0};
  int i = 0;

  PacketXi stride_index = svindex_s32(0, N);

  for (i = 0; i < N; i++) {
    svst1_scatter_s32index_s32(svptrue_b32(), buffer + i, stride_index, kernel.packet[i]);
  }
  for (i = 0; i < N; i++) {
    kernel.packet[i] = svld1_s32(svptrue_b32(), buffer + i * packet_traits<numext::int32_t>::size);
  }
}

/********************************* float32 ************************************/

typedef svfloat32_t PacketXf __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL)));

template <>
struct packet_traits<float> : default_packet_traits {
  typedef PacketXf type;
  typedef PacketXf half;

  enum {
    Vectorizable = 1,
    AlignedOnScalar = 1,
    size = sve_packet_size_selector<float, EIGEN_ARM64_SVE_VL>::size,
    HasHalfPacket = 0,

    HasAdd = 1,
    HasSub = 1,
    HasShift = 1,
    HasMul = 1,
    HasNegate = 1,
    HasAbs = 1,
    HasArg = 0,
    HasAbs2 = 1,
    HasMin = 1,
    HasMax = 1,
    HasConj = 1,
    HasSetLinear = 0,
    HasBlend = 0,
    HasReduxp = 0,  // Not implemented in SVE

    HasDiv = 1,
    HasFloor = 1,

    HasSin = EIGEN_FAST_MATH,
    HasCos = EIGEN_FAST_MATH,
    HasLog = 1,
    HasExp = 1,
    HasSqrt = 0,
    HasTanh = EIGEN_FAST_MATH,
    HasErf = EIGEN_FAST_MATH
  };
};

template <>
struct unpacket_traits<PacketXf> {
  typedef float type;
  typedef PacketXf half;  // Half not yet implemented
  typedef PacketXi integer_packet;

  enum {
    size = sve_packet_size_selector<float, EIGEN_ARM64_SVE_VL>::size,
    alignment = Aligned64,
    vectorizable = true,
    masked_load_available = false,
    masked_store_available = false
  };
};

template <>
EIGEN_STRONG_INLINE PacketXf pset1<PacketXf>(const float& from)
{
  return svdup_n_f32(from);
}

template <>
EIGEN_STRONG_INLINE PacketXf pset1frombits<PacketXf>(numext::uint32_t from)
{
  return svreinterpret_f32_u32(svdup_n_u32_z(svptrue_b32(), from));
}

template <>
EIGEN_STRONG_INLINE PacketXf plset<PacketXf>(const float& a)
{
  float c[packet_traits<float>::size];
  for (int i = 0; i < packet_traits<float>::size; i++) c[i] = i;
  return svadd_f32_z(svptrue_b32(), pset1<PacketXf>(a), svld1_f32(svptrue_b32(), c));
}
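// Example (illustrative only): plset produces an arithmetic ramp starting at
// its argument, so plset<PacketXf>(10.f) yields {10.f, 11.f, 12.f, ...}
// across the packet's lanes, whatever the hardware vector length.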
template <>
EIGEN_STRONG_INLINE PacketXf padd<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svadd_f32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf psub<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svsub_f32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pnegate<PacketXf>(const PacketXf& a)
{
  return svneg_f32_z(svptrue_b32(), a);
}

template <>
EIGEN_STRONG_INLINE PacketXf pconj<PacketXf>(const PacketXf& a)
{
  return a;
}

template <>
EIGEN_STRONG_INLINE PacketXf pmul<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svmul_f32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pdiv<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svdiv_f32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pmadd(const PacketXf& a, const PacketXf& b, const PacketXf& c)
{
  return svmla_f32_z(svptrue_b32(), c, a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pmin<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svmin_f32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pmin<PropagateNaN, PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return pmin<PacketXf>(a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pmin<PropagateNumbers, PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svminnm_f32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pmax<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svmax_f32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pmax<PropagateNaN, PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return pmax<PacketXf>(a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pmax<PropagateNumbers, PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svmaxnm_f32_z(svptrue_b32(), a, b);
}

// Float comparisons in SVE return svbool (predicate). Use svdup to set all
// bits of the active lanes (0xffffffffu) and clear the inactive lanes to 0.
template <>
EIGEN_STRONG_INLINE PacketXf pcmp_le<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(svdup_n_u32_z(svcmple_f32(svptrue_b32(), a, b), 0xffffffffu));
}

template <>
EIGEN_STRONG_INLINE PacketXf pcmp_lt<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(svdup_n_u32_z(svcmplt_f32(svptrue_b32(), a, b), 0xffffffffu));
}

template <>
EIGEN_STRONG_INLINE PacketXf pcmp_eq<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(svdup_n_u32_z(svcmpeq_f32(svptrue_b32(), a, b), 0xffffffffu));
}
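// Note on the min/max variants above (a sketch of the IEEE-754 behavior, to
// the best of our understanding of the FMIN/FMINNM semantics): svmin/svmax
// return NaN when either input is NaN, so the default and PropagateNaN
// variants propagate NaNs, while svminnm/svmaxnm prefer the numeric operand.
// For example,
//
//   pmin<PropagateNumbers, PacketXf>(pset1<PacketXf>(NAN), pset1<PacketXf>(2.f))
//
// yields a packet holding 2.f in every lane.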
// Invert the predicate (svnot_b_z) resulting from the greater-than-or-equal
// comparison (svcmpge_f32), then fill a float vector with the active lanes;
// a lane is active when a < b or at least one of a, b is NaN.
template <>
EIGEN_STRONG_INLINE PacketXf pcmp_lt_or_nan<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(svdup_n_u32_z(svnot_b_z(svptrue_b32(), svcmpge_f32(svptrue_b32(), a, b)), 0xffffffffu));
}

template <>
EIGEN_STRONG_INLINE PacketXf pfloor<PacketXf>(const PacketXf& a)
{
  return svrintm_f32_z(svptrue_b32(), a);
}

template <>
EIGEN_STRONG_INLINE PacketXf ptrue<PacketXf>(const PacketXf& /*a*/)
{
  return svreinterpret_f32_u32(svdup_n_u32_z(svptrue_b32(), 0xffffffffu));
}

// Logical operations are not supported on floats, so reinterpret-cast to
// unsigned integers, operate, and cast back.
template <>
EIGEN_STRONG_INLINE PacketXf pand<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(svand_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}

template <>
EIGEN_STRONG_INLINE PacketXf por<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(svorr_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}

template <>
EIGEN_STRONG_INLINE PacketXf pxor<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(sveor_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}

template <>
EIGEN_STRONG_INLINE PacketXf pandnot<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(svbic_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}

template <>
EIGEN_STRONG_INLINE PacketXf pload<PacketXf>(const float* from)
{
  EIGEN_DEBUG_ALIGNED_LOAD return svld1_f32(svptrue_b32(), from);
}

template <>
EIGEN_STRONG_INLINE PacketXf ploadu<PacketXf>(const float* from)
{
  EIGEN_DEBUG_UNALIGNED_LOAD return svld1_f32(svptrue_b32(), from);
}

template <>
EIGEN_STRONG_INLINE PacketXf ploaddup<PacketXf>(const float* from)
{
  svuint32_t indices = svindex_u32(0, 1);  // index {base=0, base+step=1, base+step*2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a1, a1, a2, a2, ...}
  return svld1_gather_u32index_f32(svptrue_b32(), from, indices);
}

template <>
EIGEN_STRONG_INLINE PacketXf ploadquad<PacketXf>(const float* from)
{
  svuint32_t indices = svindex_u32(0, 1);  // index {base=0, base+step=1, base+step*2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a1, a1, a2, a2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...}
  return svld1_gather_u32index_f32(svptrue_b32(), from, indices);
}

template <>
EIGEN_STRONG_INLINE void pstore<float>(float* to, const PacketXf& from)
{
  EIGEN_DEBUG_ALIGNED_STORE svst1_f32(svptrue_b32(), to, from);
}

template <>
EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const PacketXf& from)
{
  EIGEN_DEBUG_UNALIGNED_STORE svst1_f32(svptrue_b32(), to, from);
}

template <>
EIGEN_DEVICE_FUNC inline PacketXf pgather<float, PacketXf>(const float* from, Index stride)
{
  // Index format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
  svint32_t indices = svindex_s32(0, stride);
  return svld1_gather_s32index_f32(svptrue_b32(), from, indices);
}

template <>
EIGEN_DEVICE_FUNC inline void pscatter<float, PacketXf>(float* to, const PacketXf& from, Index stride)
{
  // Index format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
  svint32_t indices = svindex_s32(0, stride);
  svst1_scatter_s32index_f32(svptrue_b32(), to, indices, from);
}

template <>
EIGEN_STRONG_INLINE float pfirst<PacketXf>(const PacketXf& a)
{
  // svlasta returns the first element if all predicate bits are 0
  return svlasta_f32(svpfalse_b(), a);
}

template <>
EIGEN_STRONG_INLINE PacketXf preverse(const PacketXf& a)
{
  return svrev_f32(a);
}

template <>
EIGEN_STRONG_INLINE PacketXf pabs(const PacketXf& a)
{
  return svabs_f32_z(svptrue_b32(), a);
}
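// Sketch (illustrative only, not used by Eigen; the helper name is
// hypothetical): the reinterpret-based bitwise ops above enable sign-bit
// manipulation of floats, e.g. an alternative abs that masks off the sign bit:
//
//   PacketXf abs_via_mask(const PacketXf& a)
//   {
//     // 0x7fffffffu keeps the exponent and mantissa, clears the sign bit
//     return pand<PacketXf>(a, pset1frombits<PacketXf>(0x7fffffffu));
//   }
//
// The svabs-based pabs above is the intrinsic-backed version and is preferred.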
// TODO(tellenbach): Should this go into MathFunctions.h? If so, change for
// all vector extensions and the generic version.
template <>
EIGEN_STRONG_INLINE PacketXf pfrexp<PacketXf>(const PacketXf& a, PacketXf& exponent)
{
  return pfrexp_generic(a, exponent);
}

template <>
EIGEN_STRONG_INLINE float predux<PacketXf>(const PacketXf& a)
{
  return svaddv_f32(svptrue_b32(), a);
}

// Other reduction functions:
// mul
// Only works for SVE vector lengths that are a multiple of 128 bits.
template <>
EIGEN_STRONG_INLINE float predux_mul<PacketXf>(const PacketXf& a)
{
  EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0),
                      EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);

  // Multiply the vector by its reverse
  svfloat32_t prod = svmul_f32_z(svptrue_b32(), a, svrev_f32(a));
  svfloat32_t half_prod;

  // Extract the high half of the vector. Depending on the VL, more reduction
  // steps are needed.
  if (EIGEN_ARM64_SVE_VL >= 2048) {
    half_prod = svtbl_f32(prod, svindex_u32(32, 1));
    prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 1024) {
    half_prod = svtbl_f32(prod, svindex_u32(16, 1));
    prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 512) {
    half_prod = svtbl_f32(prod, svindex_u32(8, 1));
    prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 256) {
    half_prod = svtbl_f32(prod, svindex_u32(4, 1));
    prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
  }
  // Last reduction
  half_prod = svtbl_f32(prod, svindex_u32(2, 1));
  prod = svmul_f32_z(svptrue_b32(), prod, half_prod);

  // The reduction is done to the first element.
  return pfirst<PacketXf>(prod);
}

template <>
EIGEN_STRONG_INLINE float predux_min<PacketXf>(const PacketXf& a)
{
  return svminv_f32(svptrue_b32(), a);
}

template <>
EIGEN_STRONG_INLINE float predux_max<PacketXf>(const PacketXf& a)
{
  return svmaxv_f32(svptrue_b32(), a);
}

template <int N>
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<PacketXf, N>& kernel)
{
  float buffer[packet_traits<float>::size * N] = {0};
  int i = 0;

  PacketXi stride_index = svindex_s32(0, N);

  for (i = 0; i < N; i++) {
    svst1_scatter_s32index_f32(svptrue_b32(), buffer + i, stride_index, kernel.packet[i]);
  }

  for (i = 0; i < N; i++) {
    kernel.packet[i] = svld1_f32(svptrue_b32(), buffer + i * packet_traits<float>::size);
  }
}

template <>
EIGEN_STRONG_INLINE PacketXf pldexp<PacketXf>(const PacketXf& a, const PacketXf& exponent)
{
  return pldexp_generic(a, exponent);
}

}  // namespace internal
}  // namespace Eigen

#endif  // EIGEN_PACKET_MATH_SVE_H