aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/src/SpecialFunctions/BesselFunctionsImpl.h
diff options
context:
space:
mode:
authorGravatar Antonio Sanchez <cantonios@google.com>2020-12-02 14:00:57 -0800
committerGravatar Antonio Sanchez <cantonios@google.com>2020-12-04 10:16:29 -0800
commite2f21465fea76a80966f12a20d0be36597f19b44 (patch)
tree1ae9b0e3ae489b028902166a343f796d196fde82 /unsupported/Eigen/src/SpecialFunctions/BesselFunctionsImpl.h
parent305b8bd2777bda99f65791468f305b76021bf579 (diff)
Special function implementations for half/bfloat16 packets.
Current implementations fail to consider half-float packets, only half-float scalars. Added specializations for packets on AVX, AVX512 and NEON. Added tests to `special_packetmath`. The current `special_functions` tests would fail for half and bfloat16 due to lack of precision. The NEON tests also fail with precision issues and due to different handling of `sqrt(inf)`, so special functions bessel, ndtri have been disabled. Tested with AVX, AVX512.
Diffstat (limited to 'unsupported/Eigen/src/SpecialFunctions/BesselFunctionsImpl.h')
-rw-r--r--unsupported/Eigen/src/SpecialFunctions/BesselFunctionsImpl.h132
1 files changed, 66 insertions, 66 deletions
diff --git a/unsupported/Eigen/src/SpecialFunctions/BesselFunctionsImpl.h b/unsupported/Eigen/src/SpecialFunctions/BesselFunctionsImpl.h
index a9b6ad940..24812be1b 100644
--- a/unsupported/Eigen/src/SpecialFunctions/BesselFunctionsImpl.h
+++ b/unsupported/Eigen/src/SpecialFunctions/BesselFunctionsImpl.h
@@ -46,7 +46,7 @@ struct bessel_i0e_retval {
typedef Scalar type;
};
-template <typename T, typename ScalarType>
+template <typename T, typename ScalarType = typename unpacket_traits<T>::type>
struct generic_i0e {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T&) {
@@ -201,11 +201,11 @@ struct generic_i0e<T, double> {
}
};
-template <typename Scalar>
+template <typename T>
struct bessel_i0e_impl {
EIGEN_DEVICE_FUNC
- static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
- return generic_i0e<Scalar, Scalar>::run(x);
+ static EIGEN_STRONG_INLINE T run(const T x) {
+ return generic_i0e<T>::run(x);
}
};
@@ -214,7 +214,7 @@ struct bessel_i0_retval {
typedef Scalar type;
};
-template <typename T, typename ScalarType>
+template <typename T, typename ScalarType = typename unpacket_traits<T>::type>
struct generic_i0 {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T& x) {
@@ -224,11 +224,11 @@ struct generic_i0 {
}
};
-template <typename Scalar>
+template <typename T>
struct bessel_i0_impl {
EIGEN_DEVICE_FUNC
- static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
- return generic_i0<Scalar, Scalar>::run(x);
+ static EIGEN_STRONG_INLINE T run(const T x) {
+ return generic_i0<T>::run(x);
}
};
@@ -237,7 +237,7 @@ struct bessel_i1e_retval {
typedef Scalar type;
};
-template <typename T, typename ScalarType>
+template <typename T, typename ScalarType = typename unpacket_traits<T>::type >
struct generic_i1e {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T&) {
@@ -396,20 +396,20 @@ struct generic_i1e<T, double> {
}
};
-template <typename Scalar>
+template <typename T>
struct bessel_i1e_impl {
EIGEN_DEVICE_FUNC
- static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
- return generic_i1e<Scalar, Scalar>::run(x);
+ static EIGEN_STRONG_INLINE T run(const T x) {
+ return generic_i1e<T>::run(x);
}
};
-template <typename Scalar>
+template <typename T>
struct bessel_i1_retval {
- typedef Scalar type;
+ typedef T type;
};
-template <typename T, typename ScalarType>
+template <typename T, typename ScalarType = typename unpacket_traits<T>::type>
struct generic_i1 {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T& x) {
@@ -419,20 +419,20 @@ struct generic_i1 {
}
};
-template <typename Scalar>
+template <typename T>
struct bessel_i1_impl {
EIGEN_DEVICE_FUNC
- static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
- return generic_i1<Scalar, Scalar>::run(x);
+ static EIGEN_STRONG_INLINE T run(const T x) {
+ return generic_i1<T>::run(x);
}
};
-template <typename Scalar>
+template <typename T>
struct bessel_k0e_retval {
- typedef Scalar type;
+ typedef T type;
};
-template <typename T, typename ScalarType>
+template <typename T, typename ScalarType = typename unpacket_traits<T>::type>
struct generic_k0e {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T&) {
@@ -582,20 +582,20 @@ struct generic_k0e<T, double> {
}
};
-template <typename Scalar>
+template <typename T>
struct bessel_k0e_impl {
EIGEN_DEVICE_FUNC
- static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
- return generic_k0e<Scalar, Scalar>::run(x);
+ static EIGEN_STRONG_INLINE T run(const T x) {
+ return generic_k0e<T>::run(x);
}
};
-template <typename Scalar>
+template <typename T>
struct bessel_k0_retval {
- typedef Scalar type;
+ typedef T type;
};
-template <typename T, typename ScalarType>
+template <typename T, typename ScalarType = typename unpacket_traits<T>::type>
struct generic_k0 {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T&) {
@@ -754,20 +754,20 @@ struct generic_k0<T, double> {
}
};
-template <typename Scalar>
+template <typename T>
struct bessel_k0_impl {
EIGEN_DEVICE_FUNC
- static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
- return generic_k0<Scalar, Scalar>::run(x);
+ static EIGEN_STRONG_INLINE T run(const T x) {
+ return generic_k0<T>::run(x);
}
};
-template <typename Scalar>
+template <typename T>
struct bessel_k1e_retval {
- typedef Scalar type;
+ typedef T type;
};
-template <typename T, typename ScalarType>
+template <typename T, typename ScalarType = typename unpacket_traits<T>::type>
struct generic_k1e {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T&) {
@@ -910,20 +910,20 @@ struct generic_k1e<T, double> {
}
};
-template <typename Scalar>
+template <typename T>
struct bessel_k1e_impl {
EIGEN_DEVICE_FUNC
- static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
- return generic_k1e<Scalar, Scalar>::run(x);
+ static EIGEN_STRONG_INLINE T run(const T x) {
+ return generic_k1e<T>::run(x);
}
};
-template <typename Scalar>
+template <typename T>
struct bessel_k1_retval {
- typedef Scalar type;
+ typedef T type;
};
-template <typename T, typename ScalarType>
+template <typename T, typename ScalarType = typename unpacket_traits<T>::type>
struct generic_k1 {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T&) {
@@ -1076,20 +1076,20 @@ struct generic_k1<T, double> {
}
};
-template <typename Scalar>
+template <typename T>
struct bessel_k1_impl {
EIGEN_DEVICE_FUNC
- static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
- return generic_k1<Scalar, Scalar>::run(x);
+ static EIGEN_STRONG_INLINE T run(const T x) {
+ return generic_k1<T>::run(x);
}
};
-template <typename Scalar>
+template <typename T>
struct bessel_j0_retval {
- typedef Scalar type;
+ typedef T type;
};
-template <typename T, typename ScalarType>
+template <typename T, typename ScalarType = typename unpacket_traits<T>::type>
struct generic_j0 {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T&) {
@@ -1276,20 +1276,20 @@ struct generic_j0<T, double> {
}
};
-template <typename Scalar>
+template <typename T>
struct bessel_j0_impl {
EIGEN_DEVICE_FUNC
- static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
- return generic_j0<Scalar, Scalar>::run(x);
+ static EIGEN_STRONG_INLINE T run(const T x) {
+ return generic_j0<T>::run(x);
}
};
-template <typename Scalar>
+template <typename T>
struct bessel_y0_retval {
- typedef Scalar type;
+ typedef T type;
};
-template <typename T, typename ScalarType>
+template <typename T, typename ScalarType = typename unpacket_traits<T>::type>
struct generic_y0 {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T&) {
@@ -1474,20 +1474,20 @@ struct generic_y0<T, double> {
}
};
-template <typename Scalar>
+template <typename T>
struct bessel_y0_impl {
EIGEN_DEVICE_FUNC
- static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
- return generic_y0<Scalar, Scalar>::run(x);
+ static EIGEN_STRONG_INLINE T run(const T x) {
+ return generic_y0<T>::run(x);
}
};
-template <typename Scalar>
+template <typename T>
struct bessel_j1_retval {
- typedef Scalar type;
+ typedef T type;
};
-template <typename T, typename ScalarType>
+template <typename T, typename ScalarType = typename unpacket_traits<T>::type>
struct generic_j1 {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T&) {
@@ -1665,20 +1665,20 @@ struct generic_j1<T, double> {
}
};
-template <typename Scalar>
+template <typename T>
struct bessel_j1_impl {
EIGEN_DEVICE_FUNC
- static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
- return generic_j1<Scalar, Scalar>::run(x);
+ static EIGEN_STRONG_INLINE T run(const T x) {
+ return generic_j1<T>::run(x);
}
};
-template <typename Scalar>
+template <typename T>
struct bessel_y1_retval {
- typedef Scalar type;
+ typedef T type;
};
-template <typename T, typename ScalarType>
+template <typename T, typename ScalarType = typename unpacket_traits<T>::type>
struct generic_y1 {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T&) {
@@ -1868,11 +1868,11 @@ struct generic_y1<T, double> {
}
};
-template <typename Scalar>
+template <typename T>
struct bessel_y1_impl {
EIGEN_DEVICE_FUNC
- static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
- return generic_y1<Scalar, Scalar>::run(x);
+ static EIGEN_STRONG_INLINE T run(const T x) {
+ return generic_y1<T>::run(x);
}
};