From b733b8b680885c0fcdfddea5423171468609b5a6 Mon Sep 17 00:00:00 2001 From: Sami Kama Date: Tue, 10 Mar 2020 20:28:43 +0000 Subject: remove duplicate pset1 for half and add some comments about why we need expose pmul/add/div/min/max on host --- Eigen/src/Core/arch/GPU/PacketMath.h | 895 +++++++++++++++++++++++++++++++--- Eigen/src/Core/arch/GPU/TypeCasting.h | 33 +- 2 files changed, 849 insertions(+), 79 deletions(-) (limited to 'Eigen/src/Core/arch') diff --git a/Eigen/src/Core/arch/GPU/PacketMath.h b/Eigen/src/Core/arch/GPU/PacketMath.h index 9e18c5145..1f6a562c5 100644 --- a/Eigen/src/Core/arch/GPU/PacketMath.h +++ b/Eigen/src/Core/arch/GPU/PacketMath.h @@ -476,24 +476,29 @@ ptranspose(PacketBlock& kernel) { kernel.packet[1].x = tmp; } -#endif +#endif // defined(EIGEN_GPUCC) && defined(EIGEN_USE_GPU) -// Packet math for Eigen::half -// Most of the following operations require arch >= 3.0 -#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDACC) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \ +// Packet4h2 must be defined in the macro without EIGEN_CUDA_ARCH, meaning +// its corresponding packet_traits must be visible on host. +#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDACC)) || \ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIPCC) && defined(EIGEN_HIP_DEVICE_COMPILE)) || \ (defined(EIGEN_HAS_CUDA_FP16) && defined(__clang__) && defined(__CUDA__)) +typedef ulonglong2 Packet4h2; +template<> struct unpacket_traits { typedef Eigen::half type; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet4h2 half; }; +template<> struct is_arithmetic { enum { value = true }; }; + +template<> struct unpacket_traits { typedef Eigen::half type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef half2 half; }; template<> struct is_arithmetic { enum { value = true }; }; template<> struct packet_traits : default_packet_traits { - typedef half2 type; - typedef half2 half; + typedef Packet4h2 type; + typedef Packet4h2 half; enum { Vectorizable = 1, AlignedOnScalar = 1, - size=2, + size=8, HasHalfPacket = 0, HasAdd = 1, HasSub = 1, @@ -508,9 +513,8 @@ template<> struct packet_traits : default_packet_traits }; }; -template<> struct unpacket_traits { typedef Eigen::half type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef half2 half; }; - -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pset1(const Eigen::half& from) { +template<> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pset1(const Eigen::half& from) { #if !defined(EIGEN_CUDA_ARCH) && !defined(EIGEN_HIP_DEVICE_COMPILE) half2 r; r.x = from; @@ -521,23 +525,40 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pset1(const Eigen: #endif } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pload(const Eigen::half* from) { +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 +pset1(const Eigen::half& from) { + Packet4h2 r; + half2* p_alias = reinterpret_cast(&r); + p_alias[0] = pset1(from); + p_alias[1] = pset1(from); + p_alias[2] = pset1(from); + p_alias[3] = pset1(from); + return r; +} + +#if defined(EIGEN_CUDA_ARCH) || defined(EIGEN_HIP_DEVICE_COMPILE) || (defined(EIGEN_CUDACC) && EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) +namespace { + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pload(const Eigen::half* from) { return *reinterpret_cast(from); } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 ploadu(const Eigen::half* from) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 ploadu(const Eigen::half* from) { return __halves2half2(from[0], from[1]); } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 ploaddup(const Eigen::half* from) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 ploaddup(const Eigen::half* from) { return __halves2half2(from[0], from[0]); } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstore(Eigen::half* to, const half2& from) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstore(Eigen::half* to, + const half2& from) { *reinterpret_cast(to) = from; } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstoreu(Eigen::half* to, const half2& from) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstoreu(Eigen::half* to, + const half2& from) { #if !defined(EIGEN_CUDA_ARCH) && !defined(EIGEN_HIP_DEVICE_COMPILE) to[0] = from.x; to[1] = from.y; @@ -547,8 +568,9 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstoreu(Eigen #endif } -template<> - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE half2 ploadt_ro(const Eigen::half* from) { + +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE half2 ploadt_ro_aligned( + const Eigen::half* from) { #if defined(EIGEN_HIP_DEVICE_COMPILE) @@ -565,8 +587,8 @@ template<> #endif } -template<> -EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE half2 ploadt_ro(const Eigen::half* from) { +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE half2 ploadt_ro_unaligned( + const Eigen::half* from) { #if defined(EIGEN_HIP_DEVICE_COMPILE) @@ -583,20 +605,22 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE half2 ploadt_ro(const Ei #endif } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pgather(const Eigen::half* from, Index stride) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pgather(const Eigen::half* from, + Index stride) { return __halves2half2(from[0*stride], from[1*stride]); } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(Eigen::half* to, const half2& from, Index stride) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter( + Eigen::half* to, const half2& from, Index stride) { to[stride*0] = __low2half(from); to[stride*1] = __high2half(from); } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half pfirst(const half2& a) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half pfirst(const half2& a) { return __low2half(a); } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pabs(const half2& a) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pabs(const half2& a) { half a1 = __low2half(a); half a2 = __high2half(a); half result1 = half_impl::raw_uint16_to_half(a1.x & 0x7FFF); @@ -604,12 +628,12 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pabs(const half2& return __halves2half2(result1, result2); } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 ptrue(const half2& a) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 ptrue(const half2& a) { half true_half = half_impl::raw_uint16_to_half(0xffffu); return pset1(true_half); } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pzero(const half2& a) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pzero(const half2& a) { half false_half = half_impl::raw_uint16_to_half(0x0000u); return pset1(false_half); } @@ -624,7 +648,7 @@ ptranspose(PacketBlock& kernel) { kernel.packet[1] = __halves2half2(a2, b2); } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plset(const Eigen::half& a) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plset(const Eigen::half& a) { #if defined(EIGEN_HIP_DEVICE_COMPILE) return __halves2half2(a, __hadd(a, __float2half(1.0f))); @@ -641,10 +665,9 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plset(const Eigen: #endif } -template <> -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pselect(const half2& mask, - const half2& a, - const half2& b) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pselect(const half2& mask, + const half2& a, + const half2& b) { half mask_low = __low2half(mask); half mask_high = __high2half(mask); half result_low = mask_low == half(0) ? __low2half(b) : __low2half(a); @@ -652,9 +675,8 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pselect(const half2& mask, return __halves2half2(result_low, result_high); } -template <> -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_eq(const half2& a, - const half2& b) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_eq(const half2& a, + const half2& b) { half true_half = half_impl::raw_uint16_to_half(0xffffu); half false_half = half_impl::raw_uint16_to_half(0x0000u); half a1 = __low2half(a); @@ -666,9 +688,8 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_eq(const half2& a, return __halves2half2(eq1, eq2); } -template <> -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_lt(const half2& a, - const half2& b) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_lt(const half2& a, + const half2& b) { half true_half = half_impl::raw_uint16_to_half(0xffffu); half false_half = half_impl::raw_uint16_to_half(0x0000u); half a1 = __low2half(a); @@ -680,9 +701,8 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_lt(const half2& a, return __halves2half2(eq1, eq2); } -template <> -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pand(const half2& a, - const half2& b) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pand(const half2& a, + const half2& b) { half a1 = __low2half(a); half a2 = __high2half(a); half b1 = __low2half(b); @@ -692,9 +712,8 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pand(const half2& a, return __halves2half2(result1, result2); } -template <> -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 por(const half2& a, - const half2& b) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 por(const half2& a, + const half2& b) { half a1 = __low2half(a); half a2 = __high2half(a); half b1 = __low2half(b); @@ -704,9 +723,8 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 por(const half2& a, return __halves2half2(result1, result2); } -template <> -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pxor(const half2& a, - const half2& b) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pxor(const half2& a, + const half2& b) { half a1 = __low2half(a); half a2 = __high2half(a); half b1 = __low2half(b); @@ -716,9 +734,8 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pxor(const half2& a, return __halves2half2(result1, result2); } -template <> -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pandnot(const half2& a, - const half2& b) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pandnot(const half2& a, + const half2& b) { half a1 = __low2half(a); half a2 = __high2half(a); half b1 = __low2half(b); @@ -728,7 +745,8 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pandnot(const half2& a, return __halves2half2(result1, result2); } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 padd(const half2& a, const half2& b) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 padd(const half2& a, + const half2& b) { #if defined(EIGEN_HIP_DEVICE_COMPILE) return __hadd2(a, b); @@ -750,7 +768,8 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 padd(const half2& #endif } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 psub(const half2& a, const half2& b) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 psub(const half2& a, + const half2& b) { #if defined(EIGEN_HIP_DEVICE_COMPILE) return __hsub2(a, b); @@ -772,7 +791,7 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 psub(const half2& #endif } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pnegate(const half2& a) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pnegate(const half2& a) { #if defined(EIGEN_HIP_DEVICE_COMPILE) return __hneg2(a); @@ -790,9 +809,10 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pnegate(const half2& a) { #endif } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pconj(const half2& a) { return a; } +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pconj(const half2& a) { return a; } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmul(const half2& a, const half2& b) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmul(const half2& a, + const half2& b) { #if defined(EIGEN_HIP_DEVICE_COMPILE) return __hmul2(a, b); @@ -814,7 +834,9 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmul(const half2& #endif } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmadd(const half2& a, const half2& b, const half2& c) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmadd(const half2& a, + const half2& b, + const half2& c) { #if defined(EIGEN_HIP_DEVICE_COMPILE) return __hfma2(a, b, c); @@ -838,7 +860,8 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmadd(const half2& #endif } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pdiv(const half2& a, const half2& b) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pdiv(const half2& a, + const half2& b) { #if defined(EIGEN_HIP_DEVICE_COMPILE) return __h2div(a, b); @@ -856,7 +879,8 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pdiv(const half2& #endif } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmin(const half2& a, const half2& b) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmin(const half2& a, + const half2& b) { float a1 = __low2float(a); float a2 = __high2float(a); float b1 = __low2float(b); @@ -866,7 +890,8 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmin(const half2& return __halves2half2(r1, r2); } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmax(const half2& a, const half2& b) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmax(const half2& a, + const half2& b) { float a1 = __low2float(a); float a2 = __high2float(a); float b1 = __low2float(b); @@ -876,7 +901,7 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmax(const half2& return __halves2half2(r1, r2); } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux(const half2& a) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux(const half2& a) { #if defined(EIGEN_HIP_DEVICE_COMPILE) return __hadd(__low2half(a), __high2half(a)); @@ -894,7 +919,7 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux(const #endif } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_max(const half2& a) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_max(const half2& a) { #if defined(EIGEN_HIP_DEVICE_COMPILE) __half first = __low2half(a); @@ -916,7 +941,7 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_max(c #endif } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_min(const half2& a) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_min(const half2& a) { #if defined(EIGEN_HIP_DEVICE_COMPILE) __half first = __low2half(a); @@ -938,7 +963,7 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_min(c #endif } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_mul(const half2& a) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_mul(const half2& a) { #if defined(EIGEN_HIP_DEVICE_COMPILE) return __hmul(__low2half(a), __high2half(a)); @@ -956,7 +981,7 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_mul(c #endif } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plog1p(const half2& a) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plog1p(const half2& a) { float a1 = __low2float(a); float a2 = __high2float(a); float r1 = log1pf(a1); @@ -964,7 +989,7 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plog1p(const half2 return __floats2half2_rn(r1, r2); } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pexpm1(const half2& a) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pexpm1(const half2& a) { float a1 = __low2float(a); float a2 = __high2float(a); float r1 = expm1f(a1); @@ -975,29 +1000,29 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pexpm1(const half2 #if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 530) || \ defined(EIGEN_HIP_DEVICE_COMPILE) -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE -half2 plog(const half2& a) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +half2 plog(const half2& a) { return h2log(a); } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE -half2 pexp(const half2& a) { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +half2 pexp(const half2& a) { return h2exp(a); } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE -half2 psqrt(const half2& a) { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +half2 psqrt(const half2& a) { return h2sqrt(a); } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE -half2 prsqrt(const half2& a) { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +half2 prsqrt(const half2& a) { return h2rsqrt(a); } #else -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plog(const half2& a) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plog(const half2& a) { float a1 = __low2float(a); float a2 = __high2float(a); float r1 = logf(a1); @@ -1005,7 +1030,7 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plog(const half2& return __floats2half2_rn(r1, r2); } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pexp(const half2& a) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pexp(const half2& a) { float a1 = __low2float(a); float a2 = __high2float(a); float r1 = expf(a1); @@ -1013,7 +1038,7 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pexp(const half2& return __floats2half2_rn(r1, r2); } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 psqrt(const half2& a) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 psqrt(const half2& a) { float a1 = __low2float(a); float a2 = __high2float(a); float r1 = sqrtf(a1); @@ -1021,7 +1046,7 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 psqrt(const half2& return __floats2half2_rn(r1, r2); } -template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 prsqrt(const half2& a) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 prsqrt(const half2& a) { float a1 = __low2float(a); float a2 = __high2float(a); float r1 = rsqrtf(a1); @@ -1029,8 +1054,728 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 prsqrt(const half2 return __floats2half2_rn(r1, r2); } #endif +} // namespace + + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 +pload(const Eigen::half* from) { + return *reinterpret_cast(from); +} + +// unaligned load; +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 +ploadu(const Eigen::half* from) { + Packet4h2 r; + half2* p_alias = reinterpret_cast(&r); + p_alias[0] = ploadu(from + 0); + p_alias[1] = ploadu(from + 2); + p_alias[2] = ploadu(from + 4); + p_alias[3] = ploadu(from + 6); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 +ploaddup(const Eigen::half* from) { + Packet4h2 r; + half2* p_alias = reinterpret_cast(&r); + p_alias[0] = ploaddup(from + 0); + p_alias[1] = ploaddup(from + 1); + p_alias[2] = ploaddup(from + 2); + p_alias[3] = ploaddup(from + 3); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstore( + Eigen::half* to, const Packet4h2& from) { + *reinterpret_cast(to) = from; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstoreu( + Eigen::half* to, const Packet4h2& from) { + const half2* from_alias = reinterpret_cast(&from); + pstoreu(to + 0,from_alias[0]); + pstoreu(to + 2,from_alias[1]); + pstoreu(to + 4,from_alias[2]); + pstoreu(to + 6,from_alias[3]); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet4h2 +ploadt_ro(const Eigen::half* from) { +#if defined(EIGEN_HIP_DEVICE_COMPILE) + + Packet4h2 r; + r = __ldg((const Packet4h2*)from); + return r; +#else // EIGEN_CUDA_ARCH + +#if EIGEN_CUDA_ARCH >= 350 + Packet4h2 r; + r = __ldg((const Packet4h2*)from); + return r; +#else + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + r_alias[0] = ploadt_ro_aligned(from + 0); + r_alias[1] = ploadt_ro_aligned(from + 2); + r_alias[2] = ploadt_ro_aligned(from + 4); + r_alias[3] = ploadt_ro_aligned(from + 6); + return r; +#endif + +#endif +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet4h2 +ploadt_ro(const Eigen::half* from) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + r_alias[0] = ploadt_ro_unaligned(from + 0); + r_alias[1] = ploadt_ro_unaligned(from + 2); + r_alias[2] = ploadt_ro_unaligned(from + 4); + r_alias[3] = ploadt_ro_unaligned(from + 6); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 +pgather(const Eigen::half* from, Index stride) { + Packet4h2 r; + half2* p_alias = reinterpret_cast(&r); + p_alias[0] = __halves2half2(from[0 * stride], from[1 * stride]); + p_alias[1] = __halves2half2(from[2 * stride], from[3 * stride]); + p_alias[2] = __halves2half2(from[4 * stride], from[5 * stride]); + p_alias[3] = __halves2half2(from[6 * stride], from[7 * stride]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter( + Eigen::half* to, const Packet4h2& from, Index stride) { + const half2* from_alias = reinterpret_cast(&from); + pscatter(to + stride * 0, from_alias[0], stride); + pscatter(to + stride * 2, from_alias[1], stride); + pscatter(to + stride * 4, from_alias[2], stride); + pscatter(to + stride * 6, from_alias[3], stride); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half pfirst( + const Packet4h2& a) { + return pfirst(*(reinterpret_cast(&a))); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pabs( + const Packet4h2& a) { + Packet4h2 r; + half2* p_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + p_alias[0] = pabs(a_alias[0]); + p_alias[1] = pabs(a_alias[1]); + p_alias[2] = pabs(a_alias[2]); + p_alias[3] = pabs(a_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 ptrue( + const Packet4h2& a) { + half true_half = half_impl::raw_uint16_to_half(0xffffu); + return pset1(true_half); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pzero(const Packet4h2& a) { + half false_half = half_impl::raw_uint16_to_half(0x0000u); + return pset1(false_half); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose_double( + double* d_row0, double* d_row1, double* d_row2, double* d_row3, + double* d_row4, double* d_row5, double* d_row6, double* d_row7) { + double d_tmp; + d_tmp = d_row0[1]; + d_row0[1] = d_row4[0]; + d_row4[0] = d_tmp; + + d_tmp = d_row1[1]; + d_row1[1] = d_row5[0]; + d_row5[0] = d_tmp; + + d_tmp = d_row2[1]; + d_row2[1] = d_row6[0]; + d_row6[0] = d_tmp; + + d_tmp = d_row3[1]; + d_row3[1] = d_row7[0]; + d_row7[0] = d_tmp; +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose_half2( + half2* f_row0, half2* f_row1, half2* f_row2, half2* f_row3) { + half2 f_tmp; + f_tmp = f_row0[1]; + f_row0[1] = f_row2[0]; + f_row2[0] = f_tmp; + + f_tmp = f_row1[1]; + f_row1[1] = f_row3[0]; + f_row3[0] = f_tmp; +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void +ptranspose_half(half2& f0, half2& f1) { + __half a1 = __low2half(f0); + __half a2 = __high2half(f0); + __half b1 = __low2half(f1); + __half b2 = __high2half(f1); + f0 = __halves2half2(a1, b1); + f1 = __halves2half2(a2, b2); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void +ptranspose(PacketBlock& kernel) { + double* d_row0 = reinterpret_cast(&kernel.packet[0]); + double* d_row1 = reinterpret_cast(&kernel.packet[1]); + double* d_row2 = reinterpret_cast(&kernel.packet[2]); + double* d_row3 = reinterpret_cast(&kernel.packet[3]); + double* d_row4 = reinterpret_cast(&kernel.packet[4]); + double* d_row5 = reinterpret_cast(&kernel.packet[5]); + double* d_row6 = reinterpret_cast(&kernel.packet[6]); + double* d_row7 = reinterpret_cast(&kernel.packet[7]); + ptranspose_double(d_row0, d_row1, d_row2, d_row3, + d_row4, d_row5, d_row6, d_row7); + + + half2* f_row0 = reinterpret_cast(d_row0); + half2* f_row1 = reinterpret_cast(d_row1); + half2* f_row2 = reinterpret_cast(d_row2); + half2* f_row3 = reinterpret_cast(d_row3); + ptranspose_half2(f_row0, f_row1, f_row2, f_row3); + ptranspose_half(f_row0[0], f_row1[0]); + ptranspose_half(f_row0[1], f_row1[1]); + ptranspose_half(f_row2[0], f_row3[0]); + ptranspose_half(f_row2[1], f_row3[1]); + + f_row0 = reinterpret_cast(d_row0 + 1); + f_row1 = reinterpret_cast(d_row1 + 1); + f_row2 = reinterpret_cast(d_row2 + 1); + f_row3 = reinterpret_cast(d_row3 + 1); + ptranspose_half2(f_row0, f_row1, f_row2, f_row3); + ptranspose_half(f_row0[0], f_row1[0]); + ptranspose_half(f_row0[1], f_row1[1]); + ptranspose_half(f_row2[0], f_row3[0]); + ptranspose_half(f_row2[1], f_row3[1]); + + f_row0 = reinterpret_cast(d_row4); + f_row1 = reinterpret_cast(d_row5); + f_row2 = reinterpret_cast(d_row6); + f_row3 = reinterpret_cast(d_row7); + ptranspose_half2(f_row0, f_row1, f_row2, f_row3); + ptranspose_half(f_row0[0], f_row1[0]); + ptranspose_half(f_row0[1], f_row1[1]); + ptranspose_half(f_row2[0], f_row3[0]); + ptranspose_half(f_row2[1], f_row3[1]); + + f_row0 = reinterpret_cast(d_row4 + 1); + f_row1 = reinterpret_cast(d_row5 + 1); + f_row2 = reinterpret_cast(d_row6 + 1); + f_row3 = reinterpret_cast(d_row7 + 1); + ptranspose_half2(f_row0, f_row1, f_row2, f_row3); + ptranspose_half(f_row0[0], f_row1[0]); + ptranspose_half(f_row0[1], f_row1[1]); + ptranspose_half(f_row2[0], f_row3[0]); + ptranspose_half(f_row2[1], f_row3[1]); + +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 +plset(const Eigen::half& a) { +#if defined(EIGEN_HIP_DEVICE_COMPILE) + + Packet4h2 r; + half2* p_alias = reinterpret_cast(&r); + p_alias[0] = __halves2half2(a, __hadd(a, __float2half(1.0f))); + p_alias[1] = __halves2half2(__hadd(a, __float2half(2.0f)), + __hadd(a, __float2half(3.0f))); + p_alias[2] = __halves2half2(__hadd(a, __float2half(4.0f)), + __hadd(a, __float2half(5.0f))); + p_alias[3] = __halves2half2(__hadd(a, __float2half(6.0f)), + __hadd(a, __float2half(7.0f))); + return r; +#else // EIGEN_CUDA_ARCH + +#if EIGEN_CUDA_ARCH >= 530 + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + + half2 b = pset1(a); + half2 c; + half2 half_offset0 = __halves2half2(__float2half(0.0f),__float2half(2.0f)); + half2 half_offset1 = __halves2half2(__float2half(4.0f),__float2half(6.0f)); + + c = __hadd2(b, half_offset0); + r_alias[0] = plset(__low2half(c)); + r_alias[1] = plset(__high2half(c)); + + c = __hadd2(b, half_offset1); + r_alias[2] = plset(__low2half(c)); + r_alias[3] = plset(__high2half(c)); + + return r; + +#else + float f = __half2float(a); + Packet4h2 r; + half2* p_alias = reinterpret_cast(&r); + p_alias[0] = __halves2half2(a, __float2half(f + 1.0f)); + p_alias[1] = __halves2half2(__float2half(f + 2.0f), __float2half(f + 3.0f)); + p_alias[2] = __halves2half2(__float2half(f + 4.0f), __float2half(f + 5.0f)); + p_alias[3] = __halves2half2(__float2half(f + 6.0f), __float2half(f + 7.0f)); + return r; +#endif #endif +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 +pselect(const Packet4h2& mask, const Packet4h2& a, + const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* mask_alias = reinterpret_cast(&mask); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = pselect(mask_alias[0], a_alias[0], b_alias[0]); + r_alias[1] = pselect(mask_alias[1], a_alias[1], b_alias[1]); + r_alias[2] = pselect(mask_alias[2], a_alias[2], b_alias[2]); + r_alias[3] = pselect(mask_alias[3], a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 +pcmp_eq(const Packet4h2& a, const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = pcmp_eq(a_alias[0], b_alias[0]); + r_alias[1] = pcmp_eq(a_alias[1], b_alias[1]); + r_alias[2] = pcmp_eq(a_alias[2], b_alias[2]); + r_alias[3] = pcmp_eq(a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pand( + const Packet4h2& a, const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = pand(a_alias[0], b_alias[0]); + r_alias[1] = pand(a_alias[1], b_alias[1]); + r_alias[2] = pand(a_alias[2], b_alias[2]); + r_alias[3] = pand(a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 por( + const Packet4h2& a, const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = por(a_alias[0], b_alias[0]); + r_alias[1] = por(a_alias[1], b_alias[1]); + r_alias[2] = por(a_alias[2], b_alias[2]); + r_alias[3] = por(a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pxor( + const Packet4h2& a, const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = pxor(a_alias[0], b_alias[0]); + r_alias[1] = pxor(a_alias[1], b_alias[1]); + r_alias[2] = pxor(a_alias[2], b_alias[2]); + r_alias[3] = pxor(a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 +pandnot(const Packet4h2& a, const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = pandnot(a_alias[0], b_alias[0]); + r_alias[1] = pandnot(a_alias[1], b_alias[1]); + r_alias[2] = pandnot(a_alias[2], b_alias[2]); + r_alias[3] = pandnot(a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 padd( + const Packet4h2& a, const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = padd(a_alias[0], b_alias[0]); + r_alias[1] = padd(a_alias[1], b_alias[1]); + r_alias[2] = padd(a_alias[2], b_alias[2]); + r_alias[3] = padd(a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 psub( + const Packet4h2& a, const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = psub(a_alias[0], b_alias[0]); + r_alias[1] = psub(a_alias[1], b_alias[1]); + r_alias[2] = psub(a_alias[2], b_alias[2]); + r_alias[3] = psub(a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pnegate(const Packet4h2& a) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + r_alias[0] = pnegate(a_alias[0]); + r_alias[1] = pnegate(a_alias[1]); + r_alias[2] = pnegate(a_alias[2]); + r_alias[3] = pnegate(a_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pconj(const Packet4h2& a) { + return a; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pmul( + const Packet4h2& a, const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = pmul(a_alias[0], b_alias[0]); + r_alias[1] = pmul(a_alias[1], b_alias[1]); + r_alias[2] = pmul(a_alias[2], b_alias[2]); + r_alias[3] = pmul(a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pmadd( + const Packet4h2& a, const Packet4h2& b, const Packet4h2& c) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + const half2* c_alias = reinterpret_cast(&c); + r_alias[0] = pmadd(a_alias[0], b_alias[0], c_alias[0]); + r_alias[1] = pmadd(a_alias[1], b_alias[1], c_alias[1]); + r_alias[2] = pmadd(a_alias[2], b_alias[2], c_alias[2]); + r_alias[3] = pmadd(a_alias[3], b_alias[3], c_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pdiv( + const Packet4h2& a, const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = pdiv(a_alias[0], b_alias[0]); + r_alias[1] = pdiv(a_alias[1], b_alias[1]); + r_alias[2] = pdiv(a_alias[2], b_alias[2]); + r_alias[3] = pdiv(a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pmin( + const Packet4h2& a, const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = pmin(a_alias[0], b_alias[0]); + r_alias[1] = pmin(a_alias[1], b_alias[1]); + r_alias[2] = pmin(a_alias[2], b_alias[2]); + r_alias[3] = pmin(a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pmax( + const Packet4h2& a, const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = pmax(a_alias[0], b_alias[0]); + r_alias[1] = pmax(a_alias[1], b_alias[1]); + r_alias[2] = pmax(a_alias[2], b_alias[2]); + r_alias[3] = pmax(a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux( + const Packet4h2& a) { + const half2* a_alias = reinterpret_cast(&a); + + return predux(a_alias[0]) + predux(a_alias[1]) + + predux(a_alias[2]) + predux(a_alias[3]); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_max( + const Packet4h2& a) { + const half2* a_alias = reinterpret_cast(&a); + half2 m0 = __halves2half2(predux_max(a_alias[0]), + predux_max(a_alias[1])); + half2 m1 = __halves2half2(predux_max(a_alias[2]), + predux_max(a_alias[3])); + __half first = predux_max(m0); + __half second = predux_max(m1); +#if EIGEN_CUDA_ARCH >= 530 + return (__hgt(first, second) ? first : second); +#else + float ffirst = __half2float(first); + float fsecond = __half2float(second); + return (ffirst > fsecond)? first: second; +#endif +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_min( + const Packet4h2& a) { + const half2* a_alias = reinterpret_cast(&a); + half2 m0 = __halves2half2(predux_min(a_alias[0]), + predux_min(a_alias[1])); + half2 m1 = __halves2half2(predux_min(a_alias[2]), + predux_min(a_alias[3])); + __half first = predux_min(m0); + __half second = predux_min(m1); +#if EIGEN_CUDA_ARCH >= 530 + return (__hlt(first, second) ? first : second); +#else + float ffirst = __half2float(first); + float fsecond = __half2float(second); + return (ffirst < fsecond)? first: second; +#endif +} + +// likely overflow/underflow +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_mul( + const Packet4h2& a) { + const half2* a_alias = reinterpret_cast(&a); + return predux_mul(pmul(pmul(a_alias[0], a_alias[1]), + pmul(a_alias[2], a_alias[3]))); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 +plog1p(const Packet4h2& a) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + r_alias[0] = plog1p(a_alias[0]); + r_alias[1] = plog1p(a_alias[1]); + r_alias[2] = plog1p(a_alias[2]); + r_alias[3] = plog1p(a_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 +pexpm1(const Packet4h2& a) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + r_alias[0] = pexpm1(a_alias[0]); + r_alias[1] = pexpm1(a_alias[1]); + r_alias[2] = pexpm1(a_alias[2]); + r_alias[3] = pexpm1(a_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 plog(const Packet4h2& a) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + r_alias[0] = plog(a_alias[0]); + r_alias[1] = plog(a_alias[1]); + r_alias[2] = plog(a_alias[2]); + r_alias[3] = plog(a_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pexp(const Packet4h2& a) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + r_alias[0] = pexp(a_alias[0]); + r_alias[1] = pexp(a_alias[1]); + r_alias[2] = pexp(a_alias[2]); + r_alias[3] = pexp(a_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 psqrt(const Packet4h2& a) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + r_alias[0] = psqrt(a_alias[0]); + r_alias[1] = psqrt(a_alias[1]); + r_alias[2] = psqrt(a_alias[2]); + r_alias[3] = psqrt(a_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 +prsqrt(const Packet4h2& a) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + r_alias[0] = prsqrt(a_alias[0]); + r_alias[1] = prsqrt(a_alias[1]); + r_alias[2] = prsqrt(a_alias[2]); + r_alias[3] = prsqrt(a_alias[3]); + return r; +} + +// The following specialized padd, pmul, pdiv, pmin, pmax, pset1 are needed for +// the implementation of GPU half reduction. +template<> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 padd(const half2& a, + const half2& b) { +#if defined(EIGEN_HIP_DEVICE_COMPILE) + + return __hadd2(a, b); + +#else // EIGEN_CUDA_ARCH + +#if EIGEN_CUDA_ARCH >= 530 + return __hadd2(a, b); +#else + float a1 = __low2float(a); + float a2 = __high2float(a); + float b1 = __low2float(b); + float b2 = __high2float(b); + float r1 = a1 + b1; + float r2 = a2 + b2; + return __floats2half2_rn(r1, r2); +#endif + +#endif +} + +template<> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmul(const half2& a, + const half2& b) { +#if defined(EIGEN_HIP_DEVICE_COMPILE) + + return __hmul2(a, b); + +#else // EIGEN_CUDA_ARCH + +#if EIGEN_CUDA_ARCH >= 530 + return __hmul2(a, b); +#else + float a1 = __low2float(a); + float a2 = __high2float(a); + float b1 = __low2float(b); + float b2 = __high2float(b); + float r1 = a1 * b1; + float r2 = a2 * b2; + return __floats2half2_rn(r1, r2); +#endif + +#endif +} + +template<> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pdiv(const half2& a, + const half2& b) { +#if defined(EIGEN_HIP_DEVICE_COMPILE) + + return __h2div(a, b); + +#else // EIGEN_CUDA_ARCH + + float a1 = __low2float(a); + float a2 = __high2float(a); + float b1 = __low2float(b); + float b2 = __high2float(b); + float r1 = a1 / b1; + float r2 = a2 / b2; + return __floats2half2_rn(r1, r2); + +#endif +} + +template<> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmin(const half2& a, + const half2& b) { + float a1 = __low2float(a); + float a2 = __high2float(a); + float b1 = __low2float(b); + float b2 = __high2float(b); + __half r1 = a1 < b1 ? __low2half(a) : __low2half(b); + __half r2 = a2 < b2 ? __high2half(a) : __high2half(b); + return __halves2half2(r1, r2); +} + +template<> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmax(const half2& a, + const half2& b) { + float a1 = __low2float(a); + float a2 = __high2float(a); + float b1 = __low2float(b); + float b2 = __high2float(b); + __half r1 = a1 > b1 ? __low2half(a) : __low2half(b); + __half r2 = a2 > b2 ? __high2half(a) : __high2half(b); + return __halves2half2(r1, r2); +} + +#endif // defined(EIGEN_CUDA_ARCH) + +#endif // defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDACC) } // end namespace internal diff --git a/Eigen/src/Core/arch/GPU/TypeCasting.h b/Eigen/src/Core/arch/GPU/TypeCasting.h index c278f3fe8..754546225 100644 --- a/Eigen/src/Core/arch/GPU/TypeCasting.h +++ b/Eigen/src/Core/arch/GPU/TypeCasting.h @@ -17,12 +17,13 @@ namespace internal { #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) + template <> struct type_casting_traits { enum { VectorizedCast = 1, - SrcCoeffRatio = 2, - TgtCoeffRatio = 1 + SrcCoeffRatio = 1, + TgtCoeffRatio = 2 }; }; @@ -32,15 +33,39 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcast(con return make_float4(r1.x, r1.y, r2.x, r2.y); } + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pcast(const float4& a, const float4& b) { + Packet4h2 r; + half2* r_alias=reinterpret_cast(&r); + r_alias[0]=__floats2half2_rn(a.x,a.y); + r_alias[1]=__floats2half2_rn(a.z,a.w); + r_alias[2]=__floats2half2_rn(b.x,b.y); + r_alias[3]=__floats2half2_rn(b.z,b.w); + return r; +} + template <> struct type_casting_traits { enum { VectorizedCast = 1, - SrcCoeffRatio = 1, - TgtCoeffRatio = 2 + SrcCoeffRatio = 2, + TgtCoeffRatio = 1 }; }; +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcast(const Packet4h2& a) { + // Simply discard the second half of the input + float4 r; + const half2* a_alias=reinterpret_cast(&a); + float2 r1 = __half22float2(a_alias[0]); + float2 r2 = __half22float2(a_alias[1]); + r.x=static_cast(r1.x); + r.y=static_cast(r1.y); + r.z=static_cast(r2.x); + r.w=static_cast(r2.y); + return r; +} + template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcast(const float4& a) { // Simply discard the second half of the input return __floats2half2_rn(a.x, a.y); -- cgit v1.2.3