diff options
Diffstat (limited to 'Eigen/src/Core/arch/AVX512/MathFunctions.h')
-rw-r--r-- | Eigen/src/Core/arch/AVX512/MathFunctions.h | 47 |
1 files changed, 9 insertions, 38 deletions
diff --git a/Eigen/src/Core/arch/AVX512/MathFunctions.h b/Eigen/src/Core/arch/AVX512/MathFunctions.h index b86afced6..83af5f5de 100644 --- a/Eigen/src/Core/arch/AVX512/MathFunctions.h +++ b/Eigen/src/Core/arch/AVX512/MathFunctions.h @@ -135,10 +135,7 @@ plog<Packet16f>(const Packet16f& _x) { p16f_minus_inf); } -template <> -EIGEN_STRONG_INLINE Packet16bf plog<Packet16bf>(const Packet16bf& _x) { - return F32ToBf16(plog<Packet16f>(Bf16ToF32(_x))); -} +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog) #endif // Exponential function. Works by writing "x = m*log(2) + r" where @@ -264,10 +261,7 @@ pexp<Packet8d>(const Packet8d& _x) { return pmax(pmul(x, e), _x); }*/ -template <> -EIGEN_STRONG_INLINE Packet16bf pexp<Packet16bf>(const Packet16bf& _x) { - return F32ToBf16(pexp<Packet16f>(Bf16ToF32(_x))); -} +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp) // Functions for sqrt. // The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step @@ -325,10 +319,7 @@ EIGEN_STRONG_INLINE Packet8d psqrt<Packet8d>(const Packet8d& x) { } #endif -template <> -EIGEN_STRONG_INLINE Packet16bf psqrt<Packet16bf>(const Packet16bf& x) { - return F32ToBf16(psqrt<Packet16f>(Bf16ToF32(x))); -} +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psqrt) // prsqrt for float. #if defined(EIGEN_VECTORIZE_AVX512ER) @@ -377,10 +368,7 @@ EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) { } #endif -template <> -EIGEN_STRONG_INLINE Packet16bf prsqrt<Packet16bf>(const Packet16bf& x) { - return F32ToBf16(prsqrt<Packet16f>(Bf16ToF32(x))); -} +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, prsqrt) // prsqrt for double. #if EIGEN_FAST_MATH @@ -435,20 +423,14 @@ Packet16f plog1p<Packet16f>(const Packet16f& _x) { return generic_plog1p(_x); } -template<> -EIGEN_STRONG_INLINE Packet16bf plog1p<Packet16bf>(const Packet16bf& _x) { - return F32ToBf16(plog1p<Packet16f>(Bf16ToF32(_x))); -} +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog1p) template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f pexpm1<Packet16f>(const Packet16f& _x) { return generic_expm1(_x); } -template<> -EIGEN_STRONG_INLINE Packet16bf pexpm1<Packet16bf>(const Packet16bf& _x) { - return F32ToBf16(pexpm1<Packet16f>(Bf16ToF32(_x))); -} +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexpm1) #endif #endif @@ -461,31 +443,20 @@ psin<Packet16f>(const Packet16f& _x) { } template <> -EIGEN_STRONG_INLINE Packet16bf psin<Packet16bf>(const Packet16bf& _x) { - return F32ToBf16(psin<Packet16f>(Bf16ToF32(_x))); -} - -template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f pcos<Packet16f>(const Packet16f& _x) { return pcos_float(_x); } template <> -EIGEN_STRONG_INLINE Packet16bf pcos<Packet16bf>(const Packet16bf& _x) { - return F32ToBf16(pcos<Packet16f>(Bf16ToF32(_x))); -} - -template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f ptanh<Packet16f>(const Packet16f& _x) { return internal::generic_fast_tanh_float(_x); } -template <> -EIGEN_STRONG_INLINE Packet16bf ptanh<Packet16bf>(const Packet16bf& _x) { - return F32ToBf16(ptanh<Packet16f>(Bf16ToF32(_x))); -} +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psin) +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pcos) +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, ptanh) } // end namespace internal |