diff options
author | Antonio Sanchez <cantonios@google.com> | 2020-11-24 16:28:07 -0800 |
---|---|---|
committer | Antonio Sánchez <cantonios@google.com> | 2020-11-30 16:28:57 +0000 |
commit | 89f90b585d24b3c07946b4ffd8064e66ad5af94a (patch) | |
tree | c29344e3c03752faaaf2f8eee847811091688262 /Eigen/src/Core/arch/AVX512/MathFunctions.h | |
parent | c5985c46f5de0a7a381262c5a8a973806db92f40 (diff) |
AVX512 missing ops.
This allows the `packetmath` tests to pass for AVX512 on skylake.
Made `half` and `bfloat16` consistent in terms of ops they support.
Note: the `log` tests are currently disabled for `bfloat16` because
they fail due to poor precision. They were previously disabled for
`Packet8bf` via a test function specialization; I removed that
specialization and disabled the test in the generic test instead.
Diffstat (limited to 'Eigen/src/Core/arch/AVX512/MathFunctions.h')
-rw-r--r-- | Eigen/src/Core/arch/AVX512/MathFunctions.h | 12 |
1 file changed, 11 insertions, 1 deletion
```diff
diff --git a/Eigen/src/Core/arch/AVX512/MathFunctions.h b/Eigen/src/Core/arch/AVX512/MathFunctions.h
index bfd30c01a..2c34868a7 100644
--- a/Eigen/src/Core/arch/AVX512/MathFunctions.h
+++ b/Eigen/src/Core/arch/AVX512/MathFunctions.h
@@ -48,6 +48,7 @@ plog<Packet8d>(const Packet8d& _x) {
   return plog_double(_x);
 }
 
+F16_PACKET_FUNCTION(Packet16f, Packet16h, plog)
 BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog)
 #endif
 
@@ -174,6 +175,7 @@ pexp<Packet8d>(const Packet8d& _x) {
   return pmax(pmul(x, e), _x);
 }*/
 
+F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp)
 BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp)
 
 // Functions for sqrt.
@@ -232,6 +234,7 @@ EIGEN_STRONG_INLINE Packet8d psqrt<Packet8d>(const Packet8d& x) {
 }
 #endif
 
+F16_PACKET_FUNCTION(Packet16f, Packet16h, psqrt)
 BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psqrt)
 
 // prsqrt for float.
@@ -256,7 +259,7 @@ prsqrt<Packet16f>(const Packet16f& _x) {
   __mmask16 inf_mask = _mm512_cmp_ps_mask(_x, p16f_inf, _CMP_EQ_OQ);
   __mmask16 not_pos_mask = _mm512_cmp_ps_mask(_x, _mm512_setzero_ps(), _CMP_LE_OQ);
   __mmask16 not_finite_pos_mask = not_pos_mask | inf_mask;
-  
+
   // Compute an approximate result using the rsqrt intrinsic, forcing +inf
   // for denormals for consistency with AVX and SSE implementations.
   Packet16f y_approx = _mm512_rsqrt14_ps(_x);
@@ -281,6 +284,7 @@ EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
 }
 #endif
 
+F16_PACKET_FUNCTION(Packet16f, Packet16h, prsqrt)
 BF16_PACKET_FUNCTION(Packet16f, Packet16bf, prsqrt)
 
 // prsqrt for double.
@@ -336,6 +340,7 @@ Packet16f plog1p<Packet16f>(const Packet16f& _x) {
   return generic_plog1p(_x);
 }
 
+F16_PACKET_FUNCTION(Packet16f, Packet16h, plog1p)
 BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog1p)
 
 template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
@@ -343,6 +348,7 @@ Packet16f pexpm1<Packet16f>(const Packet16f& _x) {
   return generic_expm1(_x);
 }
 
+F16_PACKET_FUNCTION(Packet16f, Packet16h, pexpm1)
 BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexpm1)
 #endif
 
@@ -367,6 +373,10 @@ ptanh<Packet16f>(const Packet16f& _x) {
   return internal::generic_fast_tanh_float(_x);
 }
 
+F16_PACKET_FUNCTION(Packet16f, Packet16h, psin)
+F16_PACKET_FUNCTION(Packet16f, Packet16h, pcos)
+F16_PACKET_FUNCTION(Packet16f, Packet16h, ptanh)
+
 BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psin)
 BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pcos)
 BF16_PACKET_FUNCTION(Packet16f, Packet16bf, ptanh)
```