aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/AVX512/MathFunctions.h
diff options
context:
space:
mode:
Diffstat (limited to 'Eigen/src/Core/arch/AVX512/MathFunctions.h')
-rw-r--r--Eigen/src/Core/arch/AVX512/MathFunctions.h56
1 files changed, 52 insertions, 4 deletions
diff --git a/Eigen/src/Core/arch/AVX512/MathFunctions.h b/Eigen/src/Core/arch/AVX512/MathFunctions.h
index 67043d01b..b86afced6 100644
--- a/Eigen/src/Core/arch/AVX512/MathFunctions.h
+++ b/Eigen/src/Core/arch/AVX512/MathFunctions.h
@@ -29,6 +29,12 @@ namespace internal {
#define _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(NAME, X) \
const Packet8d p8d_##NAME = _mm512_castsi512_pd(_mm512_set1_epi64(X))
+#define _EIGEN_DECLARE_CONST_Packet16bf(NAME, X) \
+ const Packet16bf p16bf_##NAME = pset1<Packet16bf>(X)
+
+#define _EIGEN_DECLARE_CONST_Packet16bf_FROM_INT(NAME, X) \
+ const Packet16bf p16bf_##NAME = preinterpret<Packet16bf,Packet16i>(pset1<Packet16i>(X))
+
// Natural logarithm
// Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2)
// and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can
@@ -128,6 +134,11 @@ plog<Packet16f>(const Packet16f& _x) {
p16f_nan),
p16f_minus_inf);
}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf plog<Packet16bf>(const Packet16bf& _x) {
+ return F32ToBf16(plog<Packet16f>(Bf16ToF32(_x)));
+}
#endif
// Exponential function. Works by writing "x = m*log(2) + r" where
@@ -253,6 +264,10 @@ pexp<Packet8d>(const Packet8d& _x) {
return pmax(pmul(x, e), _x);
}*/
+template <>
+EIGEN_STRONG_INLINE Packet16bf pexp<Packet16bf>(const Packet16bf& _x) {
+ return F32ToBf16(pexp<Packet16f>(Bf16ToF32(_x)));
+}
// Functions for sqrt.
// The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step
@@ -303,12 +318,18 @@ template <>
EIGEN_STRONG_INLINE Packet16f psqrt<Packet16f>(const Packet16f& x) {
return _mm512_sqrt_ps(x);
}
+
template <>
EIGEN_STRONG_INLINE Packet8d psqrt<Packet8d>(const Packet8d& x) {
return _mm512_sqrt_pd(x);
}
#endif
+template <>
+EIGEN_STRONG_INLINE Packet16bf psqrt<Packet16bf>(const Packet16bf& x) {
+ return F32ToBf16(psqrt<Packet16f>(Bf16ToF32(x)));
+}
+
// prsqrt for float.
#if defined(EIGEN_VECTORIZE_AVX512ER)
@@ -316,7 +337,6 @@ template <>
EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
return _mm512_rsqrt28_ps(x);
}
-
#elif EIGEN_FAST_MATH
template <>
@@ -347,8 +367,7 @@ prsqrt<Packet16f>(const Packet16f& _x) {
// For other arguments, choose the output of the intrinsic. This will
// return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(0) = +inf.
return _mm512_mask_blend_ps(not_finite_pos_mask, y_newton, y_approx);
- }
-
+}
#else
template <>
@@ -356,9 +375,13 @@ EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
_EIGEN_DECLARE_CONST_Packet16f(one, 1.0f);
return _mm512_div_ps(p16f_one, _mm512_sqrt_ps(x));
}
-
#endif
+template <>
+EIGEN_STRONG_INLINE Packet16bf prsqrt<Packet16bf>(const Packet16bf& x) {
+ return F32ToBf16(prsqrt<Packet16f>(Bf16ToF32(x)));
+}
+
// prsqrt for double.
#if EIGEN_FAST_MATH
template <>
@@ -412,10 +435,20 @@ Packet16f plog1p<Packet16f>(const Packet16f& _x) {
return generic_plog1p(_x);
}
+template<>
+EIGEN_STRONG_INLINE Packet16bf plog1p<Packet16bf>(const Packet16bf& _x) {
+ return F32ToBf16(plog1p<Packet16f>(Bf16ToF32(_x)));
+}
+
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet16f pexpm1<Packet16f>(const Packet16f& _x) {
return generic_expm1(_x);
}
+
+template<>
+EIGEN_STRONG_INLINE Packet16bf pexpm1<Packet16bf>(const Packet16bf& _x) {
+ return F32ToBf16(pexpm1<Packet16f>(Bf16ToF32(_x)));
+}
#endif
#endif
@@ -428,17 +461,32 @@ psin<Packet16f>(const Packet16f& _x) {
}
template <>
+EIGEN_STRONG_INLINE Packet16bf psin<Packet16bf>(const Packet16bf& _x) {
+ return F32ToBf16(psin<Packet16f>(Bf16ToF32(_x)));
+}
+
+template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
pcos<Packet16f>(const Packet16f& _x) {
return pcos_float(_x);
}
template <>
+EIGEN_STRONG_INLINE Packet16bf pcos<Packet16bf>(const Packet16bf& _x) {
+ return F32ToBf16(pcos<Packet16f>(Bf16ToF32(_x)));
+}
+
+template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
ptanh<Packet16f>(const Packet16f& _x) {
return internal::generic_fast_tanh_float(_x);
}
+template <>
+EIGEN_STRONG_INLINE Packet16bf ptanh<Packet16bf>(const Packet16bf& _x) {
+ return F32ToBf16(ptanh<Packet16f>(Bf16ToF32(_x)));
+}
+
} // end namespace internal
} // end namespace Eigen