aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/CUDA/Half.h
diff options
context:
space:
mode:
Diffstat (limited to 'Eigen/src/Core/arch/CUDA/Half.h')
-rw-r--r--Eigen/src/Core/arch/CUDA/Half.h69
1 files changed, 69 insertions, 0 deletions
diff --git a/Eigen/src/Core/arch/CUDA/Half.h b/Eigen/src/Core/arch/CUDA/Half.h
index 281b8e4c6..6387f2870 100644
--- a/Eigen/src/Core/arch/CUDA/Half.h
+++ b/Eigen/src/Core/arch/CUDA/Half.h
@@ -70,12 +70,18 @@ struct half : public __half {
explicit EIGEN_DEVICE_FUNC half(bool b)
: __half(internal::raw_uint16_to_half(b ? 0x3c00 : 0)) {}
+ explicit EIGEN_DEVICE_FUNC half(unsigned int ui)
+ : __half(internal::float_to_half_rtne(static_cast<float>(ui))) {}
explicit EIGEN_DEVICE_FUNC half(int i)
: __half(internal::float_to_half_rtne(static_cast<float>(i))) {}
+ explicit EIGEN_DEVICE_FUNC half(unsigned long ul)
+ : __half(internal::float_to_half_rtne(static_cast<float>(ul))) {}
explicit EIGEN_DEVICE_FUNC half(long l)
: __half(internal::float_to_half_rtne(static_cast<float>(l))) {}
explicit EIGEN_DEVICE_FUNC half(long long ll)
: __half(internal::float_to_half_rtne(static_cast<float>(ll))) {}
+ explicit EIGEN_DEVICE_FUNC half(unsigned long long ull)
+ : __half(internal::float_to_half_rtne(static_cast<float>(ull))) {}
explicit EIGEN_DEVICE_FUNC half(float f)
: __half(internal::float_to_half_rtne(f)) {}
explicit EIGEN_DEVICE_FUNC half(double d)
@@ -401,6 +407,7 @@ static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const Eigen::half& a)
static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const Eigen::half& a) {
return !(Eigen::numext::isinf)(a) && !(Eigen::numext::isnan)(a);
}
+
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half abs(const Eigen::half& a) {
Eigen::half result;
result.x = a.x & 0x7FFF;
@@ -418,6 +425,18 @@ template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half sqrt(const Eigen::h
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half pow(const Eigen::half& a, const Eigen::half& b) {
return Eigen::half(::powf(float(a), float(b)));
}
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half sin(const Eigen::half& a) {
+ return Eigen::half(::sinf(float(a)));
+}
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half cos(const Eigen::half& a) {
+ return Eigen::half(::cosf(float(a)));
+}
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half tan(const Eigen::half& a) {
+ return Eigen::half(::tanf(float(a)));
+}
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half tanh(const Eigen::half& a) {
+ return Eigen::half(::tanhf(float(a)));
+}
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half floor(const Eigen::half& a) {
return Eigen::half(::floorf(float(a)));
}
@@ -425,6 +444,51 @@ template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half ceil(const Eigen::h
return Eigen::half(::ceilf(float(a)));
}
+template <> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half mini(const Eigen::half& a, const Eigen::half& b) {
+#if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+ return __hlt(b, a) ? b : a;
+#else
+ const float f1 = static_cast<float>(a);
+ const float f2 = static_cast<float>(b);
+ return f2 < f1 ? b : a;
+#endif
+}
+template <> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half maxi(const Eigen::half& a, const Eigen::half& b) {
+#if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+ return __hlt(a, b) ? b : a;
+#else
+ const float f1 = static_cast<float>(a);
+ const float f2 = static_cast<float>(b);
+ return f1 < f2 ? b : a;
+#endif
+}
+
+#ifdef EIGEN_HAS_C99_MATH
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half lgamma(const Eigen::half& a) {
+ return Eigen::half(Eigen::numext::lgamma(static_cast<float>(a)));
+}
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half digamma(const Eigen::half& a) {
+ return Eigen::half(Eigen::numext::digamma(static_cast<float>(a)));
+}
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half zeta(const Eigen::half& x, const Eigen::half& q) {
+ return Eigen::half(Eigen::numext::zeta(static_cast<float>(x), static_cast<float>(q)));
+}
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half polygamma(const Eigen::half& n, const Eigen::half& x) {
+ return Eigen::half(Eigen::numext::polygamma(static_cast<float>(n), static_cast<float>(x)));
+}
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half erf(const Eigen::half& a) {
+ return Eigen::half(Eigen::numext::erf(static_cast<float>(a)));
+}
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half erfc(const Eigen::half& a) {
+ return Eigen::half(Eigen::numext::erfc(static_cast<float>(a)));
+}
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half igamma(const Eigen::half& a, const Eigen::half& x) {
+ return Eigen::half(Eigen::numext::igamma(static_cast<float>(a), static_cast<float>(x)));
+}
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half igammac(const Eigen::half& a, const Eigen::half& x) {
+ return Eigen::half(Eigen::numext::igammac(static_cast<float>(a), static_cast<float>(x)));
+}
+#endif
} // end namespace numext
} // end namespace Eigen
@@ -466,6 +530,11 @@ static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC int (isfinite)(const Eigen::half& a
namespace std {
+EIGEN_ALWAYS_INLINE ostream& operator << (ostream& os, const Eigen::half& v) {
+ os << static_cast<float>(v);
+ return os;
+}
+
#if __cplusplus > 199711L
template <>
struct hash<Eigen::half> {