diff options
Diffstat (limited to 'Eigen/src/Core/arch/CUDA/Half.h')
-rw-r--r-- | Eigen/src/Core/arch/CUDA/Half.h | 69 |
1 files changed, 69 insertions, 0 deletions
diff --git a/Eigen/src/Core/arch/CUDA/Half.h b/Eigen/src/Core/arch/CUDA/Half.h index 281b8e4c6..6387f2870 100644 --- a/Eigen/src/Core/arch/CUDA/Half.h +++ b/Eigen/src/Core/arch/CUDA/Half.h @@ -70,12 +70,18 @@ struct half : public __half { explicit EIGEN_DEVICE_FUNC half(bool b) : __half(internal::raw_uint16_to_half(b ? 0x3c00 : 0)) {} + explicit EIGEN_DEVICE_FUNC half(unsigned int ui) + : __half(internal::float_to_half_rtne(static_cast<float>(ui))) {} explicit EIGEN_DEVICE_FUNC half(int i) : __half(internal::float_to_half_rtne(static_cast<float>(i))) {} + explicit EIGEN_DEVICE_FUNC half(unsigned long ul) + : __half(internal::float_to_half_rtne(static_cast<float>(ul))) {} explicit EIGEN_DEVICE_FUNC half(long l) : __half(internal::float_to_half_rtne(static_cast<float>(l))) {} explicit EIGEN_DEVICE_FUNC half(long long ll) : __half(internal::float_to_half_rtne(static_cast<float>(ll))) {} + explicit EIGEN_DEVICE_FUNC half(unsigned long long ull) + : __half(internal::float_to_half_rtne(static_cast<float>(ull))) {} explicit EIGEN_DEVICE_FUNC half(float f) : __half(internal::float_to_half_rtne(f)) {} explicit EIGEN_DEVICE_FUNC half(double d) @@ -401,6 +407,7 @@ static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const Eigen::half& a) static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const Eigen::half& a) { return !(Eigen::numext::isinf)(a) && !(Eigen::numext::isnan)(a); } + template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half abs(const Eigen::half& a) { Eigen::half result; result.x = a.x & 0x7FFF; @@ -418,6 +425,18 @@ template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half sqrt(const Eigen::h template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half pow(const Eigen::half& a, const Eigen::half& b) { return Eigen::half(::powf(float(a), float(b))); } +template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half sin(const Eigen::half& a) { + return Eigen::half(::sinf(float(a))); +} +template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half cos(const Eigen::half& a) { + return Eigen::half(::cosf(float(a))); +} +template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half tan(const Eigen::half& a) { + return Eigen::half(::tanf(float(a))); +} +template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half tanh(const Eigen::half& a) { + return Eigen::half(::tanhf(float(a))); +} template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half floor(const Eigen::half& a) { return Eigen::half(::floorf(float(a))); } @@ -425,6 +444,51 @@ template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half ceil(const Eigen::h return Eigen::half(::ceilf(float(a))); } +template <> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half mini(const Eigen::half& a, const Eigen::half& b) { +#if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hlt(b, a) ? b : a; +#else + const float f1 = static_cast<float>(a); + const float f2 = static_cast<float>(b); + return f2 < f1 ? b : a; +#endif +} +template <> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half maxi(const Eigen::half& a, const Eigen::half& b) { +#if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hlt(a, b) ? b : a; +#else + const float f1 = static_cast<float>(a); + const float f2 = static_cast<float>(b); + return f1 < f2 ? b : a; +#endif +} + +#ifdef EIGEN_HAS_C99_MATH +template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half lgamma(const Eigen::half& a) { + return Eigen::half(Eigen::numext::lgamma(static_cast<float>(a))); +} +template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half digamma(const Eigen::half& a) { + return Eigen::half(Eigen::numext::digamma(static_cast<float>(a))); +} +template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half zeta(const Eigen::half& x, const Eigen::half& q) { + return Eigen::half(Eigen::numext::zeta(static_cast<float>(x), static_cast<float>(q))); +} +template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half polygamma(const Eigen::half& n, const Eigen::half& x) { + return Eigen::half(Eigen::numext::polygamma(static_cast<float>(n), static_cast<float>(x))); +} +template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half erf(const Eigen::half& a) { + return Eigen::half(Eigen::numext::erf(static_cast<float>(a))); +} +template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half erfc(const Eigen::half& a) { + return Eigen::half(Eigen::numext::erfc(static_cast<float>(a))); +} +template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half igamma(const Eigen::half& a, const Eigen::half& x) { + return Eigen::half(Eigen::numext::igamma(static_cast<float>(a), static_cast<float>(x))); +} +template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half igammac(const Eigen::half& a, const Eigen::half& x) { + return Eigen::half(Eigen::numext::igammac(static_cast<float>(a), static_cast<float>(x))); +} +#endif } // end namespace numext } // end namespace Eigen @@ -466,6 +530,11 @@ static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC int (isfinite)(const Eigen::half& a namespace std { +EIGEN_ALWAYS_INLINE ostream& operator << (ostream& os, const Eigen::half& v) { + os << static_cast<float>(v); + return os; +} + #if __cplusplus > 199711L template <> struct hash<Eigen::half> { |