aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/AVX512/MathFunctions.h
diff options
context:
space:
mode:
authorGravatar Antonio Sanchez <cantonios@google.com>2020-11-24 16:28:07 -0800
committerGravatar Antonio Sánchez <cantonios@google.com>2020-11-30 16:28:57 +0000
commit89f90b585d24b3c07946b4ffd8064e66ad5af94a (patch)
treec29344e3c03752faaaf2f8eee847811091688262 /Eigen/src/Core/arch/AVX512/MathFunctions.h
parentc5985c46f5de0a7a381262c5a8a973806db92f40 (diff)
AVX512 missing ops.
This allows the `packetmath` tests to pass for AVX512 on skylake. Made `half` and `bfloat16` consistent in terms of ops they support. Note the `log` tests are currently disabled for `bfloat16` since they fail due to poor precision (they were previously disabled for `Packet8bf` via test function specialization -- I just removed that specialization and disabled it in the generic test).
Diffstat (limited to 'Eigen/src/Core/arch/AVX512/MathFunctions.h')
-rw-r--r--Eigen/src/Core/arch/AVX512/MathFunctions.h12
1 files changed, 11 insertions, 1 deletions
diff --git a/Eigen/src/Core/arch/AVX512/MathFunctions.h b/Eigen/src/Core/arch/AVX512/MathFunctions.h
index bfd30c01a..2c34868a7 100644
--- a/Eigen/src/Core/arch/AVX512/MathFunctions.h
+++ b/Eigen/src/Core/arch/AVX512/MathFunctions.h
@@ -48,6 +48,7 @@ plog<Packet8d>(const Packet8d& _x) {
return plog_double(_x);
}
+F16_PACKET_FUNCTION(Packet16f, Packet16h, plog)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog)
#endif
@@ -174,6 +175,7 @@ pexp<Packet8d>(const Packet8d& _x) {
return pmax(pmul(x, e), _x);
}*/
+F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp)
// Functions for sqrt.
@@ -232,6 +234,7 @@ EIGEN_STRONG_INLINE Packet8d psqrt<Packet8d>(const Packet8d& x) {
}
#endif
+F16_PACKET_FUNCTION(Packet16f, Packet16h, psqrt)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psqrt)
// prsqrt for float.
@@ -256,7 +259,7 @@ prsqrt<Packet16f>(const Packet16f& _x) {
__mmask16 inf_mask = _mm512_cmp_ps_mask(_x, p16f_inf, _CMP_EQ_OQ);
__mmask16 not_pos_mask = _mm512_cmp_ps_mask(_x, _mm512_setzero_ps(), _CMP_LE_OQ);
__mmask16 not_finite_pos_mask = not_pos_mask | inf_mask;
-
+
// Compute an approximate result using the rsqrt intrinsic, forcing +inf
// for denormals for consistency with AVX and SSE implementations.
Packet16f y_approx = _mm512_rsqrt14_ps(_x);
@@ -281,6 +284,7 @@ EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
}
#endif
+F16_PACKET_FUNCTION(Packet16f, Packet16h, prsqrt)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, prsqrt)
// prsqrt for double.
@@ -336,6 +340,7 @@ Packet16f plog1p<Packet16f>(const Packet16f& _x) {
return generic_plog1p(_x);
}
+F16_PACKET_FUNCTION(Packet16f, Packet16h, plog1p)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog1p)
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
@@ -343,6 +348,7 @@ Packet16f pexpm1<Packet16f>(const Packet16f& _x) {
return generic_expm1(_x);
}
+F16_PACKET_FUNCTION(Packet16f, Packet16h, pexpm1)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexpm1)
#endif
@@ -367,6 +373,10 @@ ptanh<Packet16f>(const Packet16f& _x) {
return internal::generic_fast_tanh_float(_x);
}
+F16_PACKET_FUNCTION(Packet16f, Packet16h, psin)
+F16_PACKET_FUNCTION(Packet16f, Packet16h, pcos)
+F16_PACKET_FUNCTION(Packet16f, Packet16h, ptanh)
+
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psin)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pcos)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, ptanh)