aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h
diff options
context:
space:
mode:
authorGravatar Antonio Sanchez <cantonios@google.com>2020-12-02 14:00:57 -0800
committerGravatar Antonio Sanchez <cantonios@google.com>2020-12-04 10:16:29 -0800
commite2f21465fea76a80966f12a20d0be36597f19b44 (patch)
tree1ae9b0e3ae489b028902166a343f796d196fde82 /unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h
parent305b8bd2777bda99f65791468f305b76021bf579 (diff)
Special function implementations for half/bfloat16 packets.
Current implementations fail to consider half-float packets, only half-float scalars. Added specializations for packets on AVX, AVX512 and NEON. Added tests to `special_packetmath`. The current `special_functions` tests would fail for half and bfloat16 due to lack of precision. The NEON tests also fail with precision issues and due to different handling of `sqrt(inf)`, so special functions bessel, ndtri have been disabled. Tested with AVX, AVX512.
Diffstat (limited to 'unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h')
-rw-r--r--unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h5
1 files changed, 3 insertions, 2 deletions
diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h
index 648eb053e..cfc13aff7 100644
--- a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h
+++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h
@@ -348,7 +348,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erf_float(const T& a_x) {
template <typename T>
struct erf_impl {
EIGEN_DEVICE_FUNC
- static EIGEN_STRONG_INLINE T run(const T x) {
+ static EIGEN_STRONG_INLINE T run(const T& x) {
return generic_fast_erf_float(x);
}
};
@@ -490,7 +490,8 @@ struct erfc_impl<double> {
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T flipsign(
const T& should_flipsign, const T& x) {
- const T sign_mask = pset1<T>(-0.0);
+ typedef typename unpacket_traits<T>::type Scalar;
+ const T sign_mask = pset1<T>(Scalar(-0.0));
T sign_bit = pand<T>(should_flipsign, sign_mask);
return pxor<T>(sign_bit, x);
}