aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/GPU
diff options
context:
space:
mode:
authorGravatar Rasmus Munk Larsen <rmlarsen@google.com>2019-06-20 11:47:49 -0700
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2019-06-20 11:47:49 -0700
commit988f24b730fe812e2e31d332d33277752fba435d (patch)
tree04bc9152e7956bbc47e3ed1618b202afb9f68913 /Eigen/src/Core/arch/GPU
parente0be7f30e137eba21bbde7b3c20300ce74b637b4 (diff)
Various fixes for packet ops.
1. Fix buggy pcmp_eq and unit test for half types. 2. Add unit test for pselect and add specializations for SSE 4.1, AVX512, and half types. 3. Get rid of FIXME: Implement faster pnegate for half by XOR'ing with a sign bit mask.
Diffstat (limited to 'Eigen/src/Core/arch/GPU')
-rw-r--r--Eigen/src/Core/arch/GPU/PacketMathHalf.h46
1 files changed, 36 insertions, 10 deletions
diff --git a/Eigen/src/Core/arch/GPU/PacketMathHalf.h b/Eigen/src/Core/arch/GPU/PacketMathHalf.h
index b04a4d7d6..3273c5ea2 100644
--- a/Eigen/src/Core/arch/GPU/PacketMathHalf.h
+++ b/Eigen/src/Core/arch/GPU/PacketMathHalf.h
@@ -177,6 +177,15 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plset<half2>(const Eigen:
}
template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pselect<half2>(const half2& mask,
+ const half2& a,
+ const half2& b) {
+ half result_low = __low2half(mask) == half(0) ? __low2half(b) : __low2half(a);
+ half result_high = __high2half(mask) == half(0) ? __high2half(b) : __high2half(a);
+ return __halves2half2(result_low, result_high);
+}
+
+template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_eq<half2>(const half2& a,
const half2& b) {
half true_half = half_impl::raw_uint16_to_half(0xffffu);
@@ -726,18 +735,29 @@ template<> EIGEN_STRONG_INLINE Packet16h pandnot(const Packet16h& a,const Packet
Packet16h r; r.x = pandnot(Packet8i(a.x),Packet8i(b.x)); return r;
}
+template<> EIGEN_STRONG_INLINE Packet16h pselect(const Packet16h& mask, const Packet16h& a, const Packet16h& b) {
+ Packet16h r; r.x = _mm256_blendv_epi8(b.x, a.x, mask.x); return r;
+}
+
template<> EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a,const Packet16h& b) {
Packet16f af = half2float(a);
Packet16f bf = half2float(b);
Packet16f rf = pcmp_eq(af, bf);
- return float2half(rf);
+ // Pack the 32-bit flags into 16-bits flags.
+ __m256i lo = _mm256_castps_si256(extract256<0>(rf));
+ __m256i hi = _mm256_castps_si256(extract256<1>(rf));
+ __m128i result_lo = _mm_packs_epi32(_mm256_extractf128_si256(lo, 0),
+ _mm256_extractf128_si256(lo, 1));
+ __m128i result_hi = _mm_packs_epi32(_mm256_extractf128_si256(hi, 0),
+ _mm256_extractf128_si256(hi, 1));
+ Packet16h result; result.x = _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1);
+ return result;
}
template<> EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) {
- // FIXME we could do that with bit manipulation
- Packet16f af = half2float(a);
- Packet16f rf = pnegate(af);
- return float2half(rf);
+ Packet16h sign_mask; sign_mask.x = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
+ Packet16h result; result.x = _mm256_xor_si256(a.x, sign_mask.x);
+ return result;
}
template<> EIGEN_STRONG_INLINE Packet16h padd<Packet16h>(const Packet16h& a, const Packet16h& b) {
@@ -1182,20 +1202,26 @@ template<> EIGEN_STRONG_INLINE Packet8h pandnot(const Packet8h& a,const Packet8h
Packet8h r; r.x = _mm_andnot_si128(b.x,a.x); return r;
}
+template<> EIGEN_STRONG_INLINE Packet8h pselect(const Packet8h& mask, const Packet8h& a, const Packet8h& b) {
+ Packet8h r; r.x = _mm_blendv_epi8(b.x, a.x, mask.x); return r;
+}
+
template<> EIGEN_STRONG_INLINE Packet8h pcmp_eq(const Packet8h& a,const Packet8h& b) {
Packet8f af = half2float(a);
Packet8f bf = half2float(b);
Packet8f rf = pcmp_eq(af, bf);
- return float2half(rf);
+ // Pack the 32-bit flags into 16-bits flags.
+ Packet8h result; result.x = _mm_packs_epi32(_mm256_extractf128_si256(_mm256_castps_si256(rf), 0),
+ _mm256_extractf128_si256(_mm256_castps_si256(rf), 1));
+ return result;
}
template<> EIGEN_STRONG_INLINE Packet8h pconj(const Packet8h& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet8h pnegate(const Packet8h& a) {
- // FIXME we could do that with bit manipulation
- Packet8f af = half2float(a);
- Packet8f rf = pnegate(af);
- return float2half(rf);
+ Packet8h sign_mask; sign_mask.x = _mm_set1_epi16(static_cast<unsigned short>(0x8000));
+ Packet8h result; result.x = _mm_xor_si128(a.x, sign_mask.x);
+ return result;
}
template<> EIGEN_STRONG_INLINE Packet8h padd<Packet8h>(const Packet8h& a, const Packet8h& b) {