aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/AVX
diff options
context:
space:
mode:
authorGravatar Antonio Sanchez <cantonios@google.com>2020-11-23 16:11:01 -0800
committerGravatar Antonio Sánchez <cantonios@google.com>2020-11-24 16:46:41 +0000
commita3b300f1af7b2bb646c9e64162630ac164802ec8 (patch)
treeb7bb7c74e3f6350ff767172345fc4e089b062ef8 /Eigen/src/Core/arch/AVX
parent38abf2be4289a8da5db2d5b1db759f26800ae1d3 (diff)
Implement missing AVX half ops.
Minimal implementation of AVX `Eigen::half` ops to bring in line with `bfloat16`. Allows `packetmath_13` to pass. Also adjusted `bfloat16` packet traits to match the supported set of ops (e.g. Bessel is not actually implemented).
Diffstat (limited to 'Eigen/src/Core/arch/AVX')
-rw-r--r--Eigen/src/Core/arch/AVX/MathFunctions.h10
-rw-r--r--Eigen/src/Core/arch/AVX/PacketMath.h114
2 files changed, 101 insertions, 23 deletions
diff --git a/Eigen/src/Core/arch/AVX/MathFunctions.h b/Eigen/src/Core/arch/AVX/MathFunctions.h
index 9b123db00..e2e704d82 100644
--- a/Eigen/src/Core/arch/AVX/MathFunctions.h
+++ b/Eigen/src/Core/arch/AVX/MathFunctions.h
@@ -158,6 +158,16 @@ Packet4d prsqrt<Packet4d>(const Packet4d& _x) {
return _mm256_div_pd(p4d_one, _mm256_sqrt_pd(_x));
}
+F16_PACKET_FUNCTION(Packet8f, Packet8h, psin)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, pcos)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, plog)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, plog1p)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, pexpm1)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, pexp)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, ptanh)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, psqrt)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, prsqrt)
+
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pcos)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog)
diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h
index b68351356..e9eaaa9e0 100644
--- a/Eigen/src/Core/arch/AVX/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX/PacketMath.h
@@ -119,22 +119,34 @@ struct packet_traits<Eigen::half> : default_packet_traits {
AlignedOnScalar = 1,
size = 8,
HasHalfPacket = 0,
+
+ HasCmp = 1,
HasAdd = 1,
HasSub = 1,
HasMul = 1,
HasDiv = 1,
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
HasNegate = 1,
- HasAbs = 0,
+ HasAbs = 1,
HasAbs2 = 0,
- HasMin = 0,
- HasMax = 0,
- HasConj = 0,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
HasSetLinear = 0,
- HasSqrt = 0,
- HasRsqrt = 0,
- HasExp = 0,
- HasLog = 0,
- HasBlend = 0
+ HasLog = 1,
+ HasLog1p = 1,
+ HasExpm1 = 1,
+ HasExp = 1,
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
+ HasBlend = 0,
+ HasRound = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1
};
};
@@ -150,16 +162,24 @@ struct packet_traits<bfloat16> : default_packet_traits {
size = 8,
HasHalfPacket = 0,
- HasCmp = 1,
+ HasCmp = 1,
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
HasDiv = 1,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasAbs2 = 0,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
HasLog = 1,
HasLog1p = 1,
HasExpm1 = 1,
HasExp = 1,
- HasNdtri = 1,
- HasBessel = 1,
HasSqrt = 1,
HasRsqrt = 1,
HasTanh = EIGEN_FAST_MATH,
@@ -870,8 +890,7 @@ template<> EIGEN_STRONG_INLINE Packet4d pblend(const Selector<4>& ifPacket, cons
}
// Packet math for Eigen::half
-// TODO(cantonios): add missing packet ops
-// - pabs, pmin, pmax, plset, pround, print, pceil, pfloor, pcmp_lt, pcmp_le, pcmp_lt_or_nan
+
template<> struct unpacket_traits<Packet8h> { typedef Eigen::half type; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet8h half; };
template<> EIGEN_STRONG_INLINE Packet8h pset1<Packet8h>(const Eigen::half& from) {
@@ -914,6 +933,16 @@ ploadquad<Packet8h>(const Eigen::half* from) {
return _mm_set_epi16(b, b, b, b, a, a, a, a);
}
+template<> EIGEN_STRONG_INLINE Packet8h ptrue(const Packet8h& a) {
+ return _mm_cmpeq_epi32(a, a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pabs(const Packet8h& a) {
+ const __m128i sign_mask = _mm_set1_epi16(static_cast<numext::uint16_t>(0x8000));
+ return _mm_andnot_si128(sign_mask, a);
+}
+
EIGEN_STRONG_INLINE Packet8f half2float(const Packet8h& a) {
#ifdef EIGEN_HAS_FP16_C
return _mm256_cvtph_ps(a);
@@ -951,8 +980,21 @@ EIGEN_STRONG_INLINE Packet8h float2half(const Packet8f& a) {
#endif
}
-template<> EIGEN_STRONG_INLINE Packet8h ptrue(const Packet8h& a) {
- return _mm_cmpeq_epi32(a, a);
+template <>
+EIGEN_STRONG_INLINE Packet8h pmin<Packet8h>(const Packet8h& a,
+ const Packet8h& b) {
+ return float2half(pmin<Packet8f>(half2float(a), half2float(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pmax<Packet8h>(const Packet8h& a,
+ const Packet8h& b) {
+ return float2half(pmax<Packet8f>(half2float(a), half2float(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h plset<Packet8h>(const half& a) {
+ return float2half(plset<Packet8f>(static_cast<float>(a)));
}
template<> EIGEN_STRONG_INLINE Packet8h por(const Packet8h& a,const Packet8h& b) {
@@ -974,13 +1016,36 @@ template<> EIGEN_STRONG_INLINE Packet8h pselect(const Packet8h& mask, const Pack
return _mm_blendv_epi8(b, a, mask);
}
+template<> EIGEN_STRONG_INLINE Packet8h pround<Packet8h>(const Packet8h& a) {
+ return float2half(pround<Packet8f>(half2float(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h print<Packet8h>(const Packet8h& a) {
+ return float2half(print<Packet8f>(half2float(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pceil<Packet8h>(const Packet8h& a) {
+ return float2half(pceil<Packet8f>(half2float(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pfloor<Packet8h>(const Packet8h& a) {
+ return float2half(pfloor<Packet8f>(half2float(a)));
+}
+
template<> EIGEN_STRONG_INLINE Packet8h pcmp_eq(const Packet8h& a,const Packet8h& b) {
- Packet8f af = half2float(a);
- Packet8f bf = half2float(b);
- Packet8f rf = pcmp_eq(af, bf);
- // Pack the 32-bit flags into 16-bits flags.
- return _mm_packs_epi32(_mm256_extractf128_si256(_mm256_castps_si256(rf), 0),
- _mm256_extractf128_si256(_mm256_castps_si256(rf), 1));
+ return Pack16To8(pcmp_eq(half2float(a), half2float(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pcmp_le(const Packet8h& a,const Packet8h& b) {
+ return Pack16To8(pcmp_le(half2float(a), half2float(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pcmp_lt(const Packet8h& a,const Packet8h& b) {
+ return Pack16To8(pcmp_lt(half2float(a), half2float(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pcmp_lt_or_nan(const Packet8h& a,const Packet8h& b) {
+ return Pack16To8(pcmp_lt_or_nan(half2float(a), half2float(b)));
}
template<> EIGEN_STRONG_INLINE Packet8h pconj(const Packet8h& a) { return a; }
@@ -1148,6 +1213,8 @@ ptranspose(PacketBlock<Packet8h,4>& kernel) {
kernel.packet[3] = pload<Packet8h>(out[3]);
}
+// BFloat16 implementation.
+
EIGEN_STRONG_INLINE Packet8f Bf16ToF32(const Packet8bf& a) {
#ifdef EIGEN_VECTORIZE_AVX2
__m256i extend = _mm256_cvtepu16_epi32(a);
@@ -1262,7 +1329,8 @@ template<> EIGEN_STRONG_INLINE Packet8bf ptrue(const Packet8bf& a) {
template <>
EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) {
- return F32ToBf16(pabs<Packet8f>(Bf16ToF32(a)));
+ const __m128i sign_mask = _mm_set1_epi16(static_cast<numext::uint16_t>(0x8000));
+ return _mm_andnot_si128(sign_mask, a);
}
template <>