diff options
Diffstat (limited to 'Eigen/src/Core/arch/AVX512/MathFunctions.h')
-rw-r--r-- | Eigen/src/Core/arch/AVX512/MathFunctions.h | 56 |
1 files changed, 52 insertions, 4 deletions
diff --git a/Eigen/src/Core/arch/AVX512/MathFunctions.h b/Eigen/src/Core/arch/AVX512/MathFunctions.h index 67043d01b..b86afced6 100644 --- a/Eigen/src/Core/arch/AVX512/MathFunctions.h +++ b/Eigen/src/Core/arch/AVX512/MathFunctions.h @@ -29,6 +29,12 @@ namespace internal { #define _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(NAME, X) \ const Packet8d p8d_##NAME = _mm512_castsi512_pd(_mm512_set1_epi64(X)) +#define _EIGEN_DECLARE_CONST_Packet16bf(NAME, X) \ + const Packet16bf p16bf_##NAME = pset1<Packet16bf>(X) + +#define _EIGEN_DECLARE_CONST_Packet16bf_FROM_INT(NAME, X) \ + const Packet16bf p16bf_##NAME = preinterpret<Packet16bf,Packet16i>(pset1<Packet16i>(X)) + // Natural logarithm // Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2) // and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can @@ -128,6 +134,11 @@ plog<Packet16f>(const Packet16f& _x) { p16f_nan), p16f_minus_inf); } + +template <> +EIGEN_STRONG_INLINE Packet16bf plog<Packet16bf>(const Packet16bf& _x) { + return F32ToBf16(plog<Packet16f>(Bf16ToF32(_x))); +} #endif // Exponential function. Works by writing "x = m*log(2) + r" where @@ -253,6 +264,10 @@ pexp<Packet8d>(const Packet8d& _x) { return pmax(pmul(x, e), _x); }*/ +template <> +EIGEN_STRONG_INLINE Packet16bf pexp<Packet16bf>(const Packet16bf& _x) { + return F32ToBf16(pexp<Packet16f>(Bf16ToF32(_x))); +} // Functions for sqrt. // The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step @@ -303,12 +318,18 @@ template <> EIGEN_STRONG_INLINE Packet16f psqrt<Packet16f>(const Packet16f& x) { return _mm512_sqrt_ps(x); } + template <> EIGEN_STRONG_INLINE Packet8d psqrt<Packet8d>(const Packet8d& x) { return _mm512_sqrt_pd(x); } #endif +template <> +EIGEN_STRONG_INLINE Packet16bf psqrt<Packet16bf>(const Packet16bf& x) { + return F32ToBf16(psqrt<Packet16f>(Bf16ToF32(x))); +} + // prsqrt for float. #if defined(EIGEN_VECTORIZE_AVX512ER) @@ -316,7 +337,6 @@ template <> EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) { return _mm512_rsqrt28_ps(x); } - #elif EIGEN_FAST_MATH template <> @@ -347,8 +367,7 @@ prsqrt<Packet16f>(const Packet16f& _x) { // For other arguments, choose the output of the intrinsic. This will // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(0) = +inf. return _mm512_mask_blend_ps(not_finite_pos_mask, y_newton, y_approx); - } - +} #else template <> @@ -356,9 +375,13 @@ EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) { _EIGEN_DECLARE_CONST_Packet16f(one, 1.0f); return _mm512_div_ps(p16f_one, _mm512_sqrt_ps(x)); } - #endif +template <> +EIGEN_STRONG_INLINE Packet16bf prsqrt<Packet16bf>(const Packet16bf& x) { + return F32ToBf16(prsqrt<Packet16f>(Bf16ToF32(x))); +} + // prsqrt for double. #if EIGEN_FAST_MATH template <> @@ -412,10 +435,20 @@ Packet16f plog1p<Packet16f>(const Packet16f& _x) { return generic_plog1p(_x); } +template<> +EIGEN_STRONG_INLINE Packet16bf plog1p<Packet16bf>(const Packet16bf& _x) { + return F32ToBf16(plog1p<Packet16f>(Bf16ToF32(_x))); +} + template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f pexpm1<Packet16f>(const Packet16f& _x) { return generic_expm1(_x); } + +template<> +EIGEN_STRONG_INLINE Packet16bf pexpm1<Packet16bf>(const Packet16bf& _x) { + return F32ToBf16(pexpm1<Packet16f>(Bf16ToF32(_x))); +} #endif #endif @@ -428,17 +461,32 @@ psin<Packet16f>(const Packet16f& _x) { } template <> +EIGEN_STRONG_INLINE Packet16bf psin<Packet16bf>(const Packet16bf& _x) { + return F32ToBf16(psin<Packet16f>(Bf16ToF32(_x))); +} + +template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f pcos<Packet16f>(const Packet16f& _x) { return pcos_float(_x); } template <> +EIGEN_STRONG_INLINE Packet16bf pcos<Packet16bf>(const Packet16bf& _x) { + return F32ToBf16(pcos<Packet16f>(Bf16ToF32(_x))); +} + +template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f ptanh<Packet16f>(const Packet16f& _x) { return internal::generic_fast_tanh_float(_x); } +template <> +EIGEN_STRONG_INLINE Packet16bf ptanh<Packet16bf>(const Packet16bf& _x) { + return F32ToBf16(ptanh<Packet16f>(Bf16ToF32(_x))); +} + } // end namespace internal } // end namespace Eigen |