From 386d809bde475c65b7940f290efe80e6a05878c4 Mon Sep 17 00:00:00 2001 From: Teng Lu Date: Sat, 20 Jun 2020 19:16:24 +0000 Subject: Support BFloat16 in Eigen --- CMakeLists.txt | 11 +- Eigen/Core | 5 + Eigen/src/Core/arch/AVX512/MathFunctions.h | 56 +- Eigen/src/Core/arch/AVX512/PacketMath.h | 516 ++++++++++++++- Eigen/src/Core/arch/AVX512/TypeCasting.h | 26 + Eigen/src/Core/arch/Default/BFloat16.h | 703 +++++++++++++++++++++ Eigen/src/Core/arch/Default/TypeCasting.h | 43 ++ Eigen/src/Core/util/ConfigureVectorization.h | 3 + test/CMakeLists.txt | 1 + test/bfloat16_float.cpp | 399 ++++++++++++ test/main.h | 1 + test/numext.cpp | 1 + test/packetmath.cpp | 1 + test/packetmath_test_shared.h | 1 + unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h | 11 + unsupported/Eigen/SpecialFunctions | 2 + .../src/SpecialFunctions/BesselFunctionsBFloat16.h | 68 ++ .../SpecialFunctions/SpecialFunctionsBFloat16.h | 58 ++ unsupported/test/cxx11_tensor_reduction.cpp | 1 + 19 files changed, 1893 insertions(+), 14 deletions(-) create mode 100644 Eigen/src/Core/arch/Default/BFloat16.h create mode 100644 test/bfloat16_float.cpp create mode 100644 unsupported/Eigen/src/SpecialFunctions/BesselFunctionsBFloat16.h create mode 100644 unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsBFloat16.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 21eece2c8..28103856e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -241,13 +241,22 @@ if(NOT MSVC) option(EIGEN_TEST_AVX512 "Enable/Disable AVX512 in tests/examples" OFF) if(EIGEN_TEST_AVX512) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f -mfma -DEIGEN_ENABLE_AVX512") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f -mfma") if (NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fabi-version=6") endif() message(STATUS "Enabling AVX512 in tests/examples") endif() + option(EIGEN_TEST_AVX512DQ "Enable/Disable AVX512DQ in tests/examples" OFF) + if(EIGEN_TEST_AVX512DQ) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512dq") + if (NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fabi-version=6") + endif() + message(STATUS "Enabling AVX512DQ in tests/examples") + endif() + option(EIGEN_TEST_F16C "Enable/Disable F16C in tests/examples" OFF) if(EIGEN_TEST_F16C) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mf16c") diff --git a/Eigen/Core b/Eigen/Core index f36031557..f44b77831 100644 --- a/Eigen/Core +++ b/Eigen/Core @@ -51,6 +51,10 @@ #define EIGEN_HAS_GPU_FP16 #endif +#if defined(EIGEN_HAS_CUDA_BF16) || defined(EIGEN_HAS_HIP_BF16) + #define EIGEN_HAS_GPU_BF16 +#endif + #if (defined _OPENMP) && (!defined EIGEN_DONT_PARALLELIZE) #define EIGEN_HAS_OPENMP #endif @@ -163,6 +167,7 @@ using std::ptrdiff_t; #include "src/Core/arch/Default/ConjHelper.h" // Generic half float support #include "src/Core/arch/Default/Half.h" +#include "src/Core/arch/Default/BFloat16.h" #include "src/Core/arch/Default/TypeCasting.h" #include "src/Core/arch/Default/GenericPacketMathFunctionsFwd.h" diff --git a/Eigen/src/Core/arch/AVX512/MathFunctions.h b/Eigen/src/Core/arch/AVX512/MathFunctions.h index 67043d01b..b86afced6 100644 --- a/Eigen/src/Core/arch/AVX512/MathFunctions.h +++ b/Eigen/src/Core/arch/AVX512/MathFunctions.h @@ -29,6 +29,12 @@ namespace internal { #define _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(NAME, X) \ const Packet8d p8d_##NAME = _mm512_castsi512_pd(_mm512_set1_epi64(X)) +#define _EIGEN_DECLARE_CONST_Packet16bf(NAME, X) \ + const Packet16bf p16bf_##NAME = pset1(X) + +#define 
_EIGEN_DECLARE_CONST_Packet16bf_FROM_INT(NAME, X) \ + const Packet16bf p16bf_##NAME = preinterpret(pset1(X)) + // Natural logarithm // Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2) // and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can @@ -128,6 +134,11 @@ plog(const Packet16f& _x) { p16f_nan), p16f_minus_inf); } + +template <> +EIGEN_STRONG_INLINE Packet16bf plog(const Packet16bf& _x) { + return F32ToBf16(plog(Bf16ToF32(_x))); +} #endif // Exponential function. Works by writing "x = m*log(2) + r" where @@ -253,6 +264,10 @@ pexp(const Packet8d& _x) { return pmax(pmul(x, e), _x); }*/ +template <> +EIGEN_STRONG_INLINE Packet16bf pexp(const Packet16bf& _x) { + return F32ToBf16(pexp(Bf16ToF32(_x))); +} // Functions for sqrt. // The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step @@ -303,12 +318,18 @@ template <> EIGEN_STRONG_INLINE Packet16f psqrt(const Packet16f& x) { return _mm512_sqrt_ps(x); } + template <> EIGEN_STRONG_INLINE Packet8d psqrt(const Packet8d& x) { return _mm512_sqrt_pd(x); } #endif +template <> +EIGEN_STRONG_INLINE Packet16bf psqrt(const Packet16bf& x) { + return F32ToBf16(psqrt(Bf16ToF32(x))); +} + // prsqrt for float. #if defined(EIGEN_VECTORIZE_AVX512ER) @@ -316,7 +337,6 @@ template <> EIGEN_STRONG_INLINE Packet16f prsqrt(const Packet16f& x) { return _mm512_rsqrt28_ps(x); } - #elif EIGEN_FAST_MATH template <> @@ -347,8 +367,7 @@ prsqrt(const Packet16f& _x) { // For other arguments, choose the output of the intrinsic. This will // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(0) = +inf. return _mm512_mask_blend_ps(not_finite_pos_mask, y_newton, y_approx); - } - +} #else template <> @@ -356,9 +375,13 @@ EIGEN_STRONG_INLINE Packet16f prsqrt(const Packet16f& x) { _EIGEN_DECLARE_CONST_Packet16f(one, 1.0f); return _mm512_div_ps(p16f_one, _mm512_sqrt_ps(x)); } - #endif +template <> +EIGEN_STRONG_INLINE Packet16bf prsqrt(const Packet16bf& x) { + return F32ToBf16(prsqrt(Bf16ToF32(x))); +} + // prsqrt for double. 
#if EIGEN_FAST_MATH template <> @@ -412,10 +435,20 @@ Packet16f plog1p(const Packet16f& _x) { return generic_plog1p(_x); } +template<> +EIGEN_STRONG_INLINE Packet16bf plog1p(const Packet16bf& _x) { + return F32ToBf16(plog1p(Bf16ToF32(_x))); +} + template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f pexpm1(const Packet16f& _x) { return generic_expm1(_x); } + +template<> +EIGEN_STRONG_INLINE Packet16bf pexpm1(const Packet16bf& _x) { + return F32ToBf16(pexpm1(Bf16ToF32(_x))); +} #endif #endif @@ -427,18 +460,33 @@ psin(const Packet16f& _x) { return psin_float(_x); } +template <> +EIGEN_STRONG_INLINE Packet16bf psin(const Packet16bf& _x) { + return F32ToBf16(psin(Bf16ToF32(_x))); +} + template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f pcos(const Packet16f& _x) { return pcos_float(_x); } +template <> +EIGEN_STRONG_INLINE Packet16bf pcos(const Packet16bf& _x) { + return F32ToBf16(pcos(Bf16ToF32(_x))); +} + template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f ptanh(const Packet16f& _x) { return internal::generic_fast_tanh_float(_x); } +template <> +EIGEN_STRONG_INLINE Packet16bf ptanh(const Packet16bf& _x) { + return F32ToBf16(ptanh(Bf16ToF32(_x))); +} + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index ad37ad620..ed15e0889 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -362,6 +362,25 @@ EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) { } #endif +// Helper function for bit packing snippet of low precision comparison. +// It packs the flags from 32x16 to 16x16. +EIGEN_STRONG_INLINE __m256i Pack32To16(Packet16f rf) { + // Split data into small pieces and handle with AVX instructions + // to guarantee internal order of vector. + // Operation: + // dst[15:0] := Saturate16(rf[31:0]) + // dst[31:16] := Saturate16(rf[63:32]) + // ... + // dst[255:240] := Saturate16(rf[255:224]) + __m256i lo = _mm256_castps_si256(extract256<0>(rf)); + __m256i hi = _mm256_castps_si256(extract256<1>(rf)); + __m128i result_lo = _mm_packs_epi32(_mm256_extractf128_si256(lo, 0), + _mm256_extractf128_si256(lo, 1)); + __m128i result_hi = _mm_packs_epi32(_mm256_extractf128_si256(hi, 0), + _mm256_extractf128_si256(hi, 1)); + return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1); +} + template <> EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) { __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ); @@ -1342,15 +1361,7 @@ template<> EIGEN_STRONG_INLINE Packet16h pselect(const Packet16h& mask, const Pa template<> EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a,const Packet16h& b) { Packet16f af = half2float(a); Packet16f bf = half2float(b); - Packet16f rf = pcmp_eq(af, bf); - // Pack the 32-bit flags into 16-bits flags. 
- __m256i lo = _mm256_castps_si256(extract256<0>(rf)); - __m256i hi = _mm256_castps_si256(extract256<1>(rf)); - __m128i result_lo = _mm_packs_epi32(_mm256_extractf128_si256(lo, 0), - _mm256_extractf128_si256(lo, 1)); - __m128i result_hi = _mm_packs_epi32(_mm256_extractf128_si256(hi, 0), - _mm256_extractf128_si256(hi, 1)); - return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1); + return Pack32To16(pcmp_eq(af, bf)); } template<> EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) { @@ -1607,6 +1618,493 @@ ptranspose(PacketBlock& kernel) { kernel.packet[3] = pload(out[3]); } +typedef union { +#ifdef EIGEN_VECTORIZE_AVX512BF16 + __m256bh bh; +#endif + Packet8i i; // __m256i; +} Packet16bf; + +template <> struct is_arithmetic { enum { value = true }; }; + +template <> +struct packet_traits : default_packet_traits { + typedef Packet16bf type; + // There is no half-size packet for current Packet16bf. + // TODO: support as SSE/AVX path. + typedef Packet16bf half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 16, + HasHalfPacket = 0, + HasBlend = 0, + HasInsert = 1, + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, +#if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT) +#ifdef EIGEN_VECTORIZE_AVX512DQ + HasLog = 1, + HasLog1p = 1, + HasExpm1 = 1, + HasNdtri = 1, + HasBessel = 1, +#endif + HasExp = 1, + HasSqrt = EIGEN_FAST_MATH, + HasRsqrt = EIGEN_FAST_MATH, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, +#endif + HasDiv = 1 + }; +}; + +template <> +struct unpacket_traits +{ + typedef bfloat16 type; + enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false}; + typedef Packet16bf half; +}; + +template <> +EIGEN_STRONG_INLINE Packet16bf pset1(const bfloat16& from) { + Packet16bf r; + r.i = _mm256_set1_epi16(from.value); + return r; +} + +template <> +EIGEN_STRONG_INLINE bfloat16 pfirst(const Packet16bf& from) { + bfloat16 t; + t.value = static_cast(_mm256_extract_epi16(from.i, 0)); + return t; +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pload(const bfloat16* from) { + Packet16bf r; + r.i = _mm256_load_si256(reinterpret_cast(from)); + return r; +} + +template <> +EIGEN_STRONG_INLINE Packet16bf ploadu(const bfloat16* from) { + Packet16bf r; + r.i = _mm256_loadu_si256(reinterpret_cast(from)); + return r; +} + +template <> +EIGEN_STRONG_INLINE void pstore(bfloat16* to, + const Packet16bf& from) { + _mm256_store_si256(reinterpret_cast<__m256i*>(to), from.i); +} + +template <> +EIGEN_STRONG_INLINE void pstoreu(bfloat16* to, + const Packet16bf& from) { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from.i); +} + +template<> EIGEN_STRONG_INLINE Packet16bf +ploaddup(const bfloat16* from) { + Packet16bf r; + unsigned short a = from[0].value; + unsigned short b = from[1].value; + unsigned short c = from[2].value; + unsigned short d = from[3].value; + unsigned short e = from[4].value; + unsigned short f = from[5].value; + unsigned short g = from[6].value; + unsigned short h = from[7].value; + r.i = _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a); + return r; +} + +template<> EIGEN_STRONG_INLINE Packet16bf +ploadquad(const bfloat16* from) { + Packet16bf r; + unsigned short a = from[0].value; + unsigned short b = from[1].value; + unsigned short c = from[2].value; + unsigned short d = from[3].value; + r.i = _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a); + return r; +} + +EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) { + return 
_mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a.i), 16));
+}
+
+// Convert float to bfloat16 according to the round-to-nearest-even / flush-denormals-to-zero algorithm.
+EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
+  Packet16bf r;
+
+  // Flush input denormal values to zero with hardware capability.
+  _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
+  __m512 flush = _mm512_and_ps(a, a);
+  _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF);
+
+#if defined(EIGEN_VECTORIZE_AVX512BF16)
+  r.bh = _mm512_cvtneps_pbh(flush);
+#else
+  __m512i t;
+  __m512i input = _mm512_castps_si512(flush);
+  __m512i nan = _mm512_set1_epi32(0x7fc0);
+
+  // uint32_t lsb = (input >> 16) & 1;
+  t = _mm512_and_si512(_mm512_srli_epi32(input, 16), _mm512_set1_epi32(1));
+  // uint32_t rounding_bias = 0x7fff + lsb;
+  t = _mm512_add_epi32(t, _mm512_set1_epi32(0x7fff));
+  // input += rounding_bias;
+  t = _mm512_add_epi32(t, input);
+  // input = input >> 16;
+  t = _mm512_srli_epi32(t, 16);
+
+  // Check NaN before converting back to bf16.
+  __mmask16 mask = _mm512_cmp_ps_mask(flush, flush, _CMP_ORD_Q);
+  t = _mm512_mask_blend_epi32(mask, nan, t);
+
+  // output.value = static_cast<uint16_t>(input);
+  r.i = _mm512_cvtepi32_epi16(t);
+#endif
+
+  return r;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf ptrue(const Packet16bf& a) {
+  Packet16bf r;
+  r.i = ptrue(a.i);
+  return r;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf por(const Packet16bf& a, const Packet16bf& b) {
+  Packet16bf r;
+  r.i = por(a.i, b.i);
+  return r;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pxor(const Packet16bf& a, const Packet16bf& b) {
+  Packet16bf r;
+  r.i = pxor(a.i, b.i);
+  return r;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pand(const Packet16bf& a, const Packet16bf& b) {
+  Packet16bf r;
+  r.i = pand(a.i, b.i);
+  return r;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pandnot(const Packet16bf& a,
+                                       const Packet16bf& b) {
+  Packet16bf r;
+  r.i = pandnot(a.i, b.i);
+  return r;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pselect(const Packet16bf& mask,
+                                       const Packet16bf& a,
+                                       const Packet16bf& b) {
+  // Input mask is expected to be all 0/1, handle it with 8-bit
+  // intrinsic for performance.
+ Packet16bf r; + r.i = _mm256_blendv_epi8(b.i, a.i, mask.i); + return r; +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pcmp_eq(const Packet16bf& a, + const Packet16bf& b) { + Packet16bf result; + result.i = Pack32To16(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b))); + return result; +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pcmp_le(const Packet16bf& a, + const Packet16bf& b) { + Packet16bf result; + result.i = Pack32To16(pcmp_le(Bf16ToF32(a), Bf16ToF32(b))); + return result; +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pcmp_lt(const Packet16bf& a, + const Packet16bf& b) { + Packet16bf result; + result.i = Pack32To16(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b))); + return result; +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pcmp_lt_or_nan(const Packet16bf& a, + const Packet16bf& b) { + Packet16bf result; + result.i = Pack32To16(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b))); + return result; +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pnegate(const Packet16bf& a) { + Packet16bf sign_mask; + sign_mask.i = _mm256_set1_epi16(static_cast(0x8000)); + Packet16bf result; + result.i = _mm256_xor_si256(a.i, sign_mask.i); + return result; +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pconj(const Packet16bf& a) { + return a; +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pabs(const Packet16bf& a) { + return F32ToBf16(pabs(Bf16ToF32(a))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf padd(const Packet16bf& a, + const Packet16bf& b) { + return F32ToBf16(padd(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf psub(const Packet16bf& a, + const Packet16bf& b) { + return F32ToBf16(psub(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pmul(const Packet16bf& a, + const Packet16bf& b) { + return F32ToBf16(pmul(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pdiv(const Packet16bf& a, + const Packet16bf& b) { + return F32ToBf16(pdiv(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pmin(const Packet16bf& a, + const Packet16bf& b) { + return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pmax(const Packet16bf& a, + const Packet16bf& b) { + return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE bfloat16 predux(const Packet16bf& p) { + return static_cast(predux(Bf16ToF32(p))); +} + +template <> +EIGEN_STRONG_INLINE bfloat16 predux_mul(const Packet16bf& from) { + return static_cast(predux_mul(Bf16ToF32(from))); +} + +template <> +EIGEN_STRONG_INLINE bfloat16 predux_min(const Packet16bf& from) { + return static_cast(predux_min(Bf16ToF32(from))); +} + +template <> +EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet16bf& from) { + return static_cast(predux_max(Bf16ToF32(from))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf preverse(const Packet16bf& a) { + __m256i m = _mm256_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1, + 14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1); + + Packet16bf res; + // Swap hi and lo first because shuffle is in 128-bit lanes. + res.i = _mm256_permute2x128_si256(a.i, a.i, 1); + // Shuffle 8-bit values in src within 2*128-bit lanes. 
+ res.i = _mm256_shuffle_epi8(a.i, m); + return res; +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pgather(const bfloat16* from, + Index stride) { + Packet16bf result; + result.i = _mm256_set_epi16( + from[15*stride].value, from[14*stride].value, from[13*stride].value, from[12*stride].value, + from[11*stride].value, from[10*stride].value, from[9*stride].value, from[8*stride].value, + from[7*stride].value, from[6*stride].value, from[5*stride].value, from[4*stride].value, + from[3*stride].value, from[2*stride].value, from[1*stride].value, from[0*stride].value); + return result; +} + +template <> +EIGEN_STRONG_INLINE void pscatter(bfloat16* to, + const Packet16bf& from, + Index stride) { + EIGEN_ALIGN64 bfloat16 aux[16]; + pstore(aux, from); + to[stride*0].value = aux[0].value; + to[stride*1].value = aux[1].value; + to[stride*2].value = aux[2].value; + to[stride*3].value = aux[3].value; + to[stride*4].value = aux[4].value; + to[stride*5].value = aux[5].value; + to[stride*6].value = aux[6].value; + to[stride*7].value = aux[7].value; + to[stride*8].value = aux[8].value; + to[stride*9].value = aux[9].value; + to[stride*10].value = aux[10].value; + to[stride*11].value = aux[11].value; + to[stride*12].value = aux[12].value; + to[stride*13].value = aux[13].value; + to[stride*14].value = aux[14].value; + to[stride*15].value = aux[15].value; +} + +EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + __m256i a = kernel.packet[0].i; + __m256i b = kernel.packet[1].i; + __m256i c = kernel.packet[2].i; + __m256i d = kernel.packet[3].i; + __m256i e = kernel.packet[4].i; + __m256i f = kernel.packet[5].i; + __m256i g = kernel.packet[6].i; + __m256i h = kernel.packet[7].i; + __m256i i = kernel.packet[8].i; + __m256i j = kernel.packet[9].i; + __m256i k = kernel.packet[10].i; + __m256i l = kernel.packet[11].i; + __m256i m = kernel.packet[12].i; + __m256i n = kernel.packet[13].i; + __m256i o = kernel.packet[14].i; + __m256i p = kernel.packet[15].i; + + __m256i ab_07 = _mm256_unpacklo_epi16(a, b); + __m256i cd_07 = _mm256_unpacklo_epi16(c, d); + __m256i ef_07 = _mm256_unpacklo_epi16(e, f); + __m256i gh_07 = _mm256_unpacklo_epi16(g, h); + __m256i ij_07 = _mm256_unpacklo_epi16(i, j); + __m256i kl_07 = _mm256_unpacklo_epi16(k, l); + __m256i mn_07 = _mm256_unpacklo_epi16(m, n); + __m256i op_07 = _mm256_unpacklo_epi16(o, p); + + __m256i ab_8f = _mm256_unpackhi_epi16(a, b); + __m256i cd_8f = _mm256_unpackhi_epi16(c, d); + __m256i ef_8f = _mm256_unpackhi_epi16(e, f); + __m256i gh_8f = _mm256_unpackhi_epi16(g, h); + __m256i ij_8f = _mm256_unpackhi_epi16(i, j); + __m256i kl_8f = _mm256_unpackhi_epi16(k, l); + __m256i mn_8f = _mm256_unpackhi_epi16(m, n); + __m256i op_8f = _mm256_unpackhi_epi16(o, p); + + __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07); + __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07); + __m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07); + __m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07); + __m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07); + __m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07); + __m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07); + __m256i mnop_47 = _mm256_unpackhi_epi32(mn_07, op_07); + + __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f); + __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f); + __m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f); + __m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f); + __m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f); + __m256i ijkl_cf = _mm256_unpackhi_epi32(ij_8f, kl_8f); + __m256i mnop_8b = 
_mm256_unpacklo_epi32(mn_8f, op_8f);
+  __m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f);
+
+  __m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03);
+  __m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03);
+  __m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, mnop_03);
+  __m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03);
+  __m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47);
+  __m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47);
+  __m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47);
+  __m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47);
+  __m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b);
+  __m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b);
+  __m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b);
+  __m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b);
+  __m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf);
+  __m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf);
+  __m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf);
+  __m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);
+
+  // NOTE: no unpacklo/hi instr in this case, so using permute instr.
+  // Selector 0x20 keeps the low 128-bit halves of both inputs (output rows
+  // 0-7); selector 0x31 keeps the high 128-bit halves (output rows 8-15).
+  kernel.packet[0].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01,
+                                                 0x20);
+  kernel.packet[1].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23,
+                                                 0x20);
+  kernel.packet[2].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45,
+                                                 0x20);
+  kernel.packet[3].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67,
+                                                 0x20);
+  kernel.packet[4].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89,
+                                                 0x20);
+  kernel.packet[5].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab,
+                                                 0x20);
+  kernel.packet[6].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd,
+                                                 0x20);
+  kernel.packet[7].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef,
+                                                 0x20);
+  kernel.packet[8].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01,
+                                                 0x31);
+  kernel.packet[9].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23,
+                                                 0x31);
+  kernel.packet[10].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45,
+                                                  0x31);
+  kernel.packet[11].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67,
+                                                  0x31);
+  kernel.packet[12].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89,
+                                                  0x31);
+  kernel.packet[13].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab,
+                                                  0x31);
+  kernel.packet[14].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd,
+                                                  0x31);
+  kernel.packet[15].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef,
+                                                  0x31);
+}
+
+EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,4>& kernel) {
+  __m256i a = kernel.packet[0].i;
+  __m256i b = kernel.packet[1].i;
+  __m256i c = kernel.packet[2].i;
+  __m256i d = kernel.packet[3].i;
+
+  __m256i ab_07 = _mm256_unpacklo_epi16(a, b);
+  __m256i cd_07 = _mm256_unpacklo_epi16(c, d);
+  __m256i ab_8f = _mm256_unpackhi_epi16(a, b);
+  __m256i cd_8f = _mm256_unpackhi_epi16(c, d);
+
+  __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
+  __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
+  __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
+  __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
+
+  // NOTE: no unpacklo/hi instr in this case, so using permute instr.
+ kernel.packet[0].i = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x20); + kernel.packet[1].i = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x20); + kernel.packet[2].i = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x31); + kernel.packet[3].i = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x31); +} } // end namespace internal diff --git a/Eigen/src/Core/arch/AVX512/TypeCasting.h b/Eigen/src/Core/arch/AVX512/TypeCasting.h index a82176941..e643b18a7 100644 --- a/Eigen/src/Core/arch/AVX512/TypeCasting.h +++ b/Eigen/src/Core/arch/AVX512/TypeCasting.h @@ -40,6 +40,32 @@ template<> EIGEN_STRONG_INLINE Packet16h pcast(const Packe return float2half(a); } +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + +template<> EIGEN_STRONG_INLINE Packet16f pcast(const Packet16bf& a) { + return Bf16ToF32(a); +} + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + +template<> EIGEN_STRONG_INLINE Packet16bf pcast(const Packet16f& a) { + return F32ToBf16(a); +} + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h new file mode 100644 index 000000000..c3725d473 --- /dev/null +++ b/Eigen/src/Core/arch/Default/BFloat16.h @@ -0,0 +1,703 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#ifndef EIGEN_BFLOAT16_H +#define EIGEN_BFLOAT16_H + +#if __cplusplus > 199711L +#define EIGEN_EXPLICIT_CAST(tgt_type) explicit operator tgt_type() +#else +#define EIGEN_EXPLICIT_CAST(tgt_type) operator tgt_type() +#endif + +namespace Eigen { + +struct bfloat16; + +namespace bfloat16_impl { + +// Make our own __bfloat16_raw definition. +struct __bfloat16_raw { + EIGEN_DEVICE_FUNC __bfloat16_raw() : value(0) {} + explicit EIGEN_DEVICE_FUNC __bfloat16_raw(unsigned short raw) : value(raw) {} + unsigned short value; +}; + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value); +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff); +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h); + +struct bfloat16_base : public __bfloat16_raw { + EIGEN_DEVICE_FUNC bfloat16_base() {} + EIGEN_DEVICE_FUNC bfloat16_base(const __bfloat16_raw& h) : __bfloat16_raw(h) {} +}; + +} // namespace bfloat16_impl + +// Class definition. +struct bfloat16 : public bfloat16_impl::bfloat16_base { + + typedef bfloat16_impl::__bfloat16_raw __bfloat16_raw; + + EIGEN_DEVICE_FUNC bfloat16() {} + + EIGEN_DEVICE_FUNC bfloat16(const __bfloat16_raw& h) : bfloat16_impl::bfloat16_base(h) {} + + explicit EIGEN_DEVICE_FUNC bfloat16(bool b) + : bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 
0x3f80 : 0)) {} + template + explicit EIGEN_DEVICE_FUNC bfloat16(const T& val) + : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(static_cast(val))) {} + explicit EIGEN_DEVICE_FUNC bfloat16(float f) + : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(f)) {} + // Following the convention of numpy, converting between complex and + // float will lead to loss of imag value. + // Single precision complex. + typedef std::complex complex64; + // Double precision complex. + typedef std::complex complex128; + explicit EIGEN_DEVICE_FUNC bfloat16(const complex64& val) + : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(val.real())) {} + explicit EIGEN_DEVICE_FUNC bfloat16(const complex128& val) + : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(static_cast(val.real()))) {} + + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(bool) const { + // +0.0 and -0.0 become false, everything else becomes true. + return (value & 0x7fff) != 0; + } + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(signed char) const { + return static_cast(bfloat16_impl::bfloat16_to_float(*this)); + } + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned char) const { + return static_cast(bfloat16_impl::bfloat16_to_float(*this)); + } + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(short) const { + return static_cast(bfloat16_impl::bfloat16_to_float(*this)); + } + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned short) const { + return static_cast(bfloat16_impl::bfloat16_to_float(*this)); + } + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(int) const { + return static_cast(bfloat16_impl::bfloat16_to_float(*this)); + } + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned int) const { + return static_cast(bfloat16_impl::bfloat16_to_float(*this)); + } + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(long) const { + return static_cast(bfloat16_impl::bfloat16_to_float(*this)); + } + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned long) const { + return static_cast(bfloat16_impl::bfloat16_to_float(*this)); + } + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(long long) const { + return static_cast(bfloat16_impl::bfloat16_to_float(*this)); + } + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned long long) const { + return static_cast(bfloat16_to_float(*this)); + } + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const { + return bfloat16_impl::bfloat16_to_float(*this); + } + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(double) const { + return static_cast(bfloat16_impl::bfloat16_to_float(*this)); + } + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(complex64) const { + return complex64(bfloat16_impl::bfloat16_to_float(*this), float(0.0)); + } + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(complex128) const { + return complex128(static_cast(bfloat16_impl::bfloat16_to_float(*this)), double(0.0)); + } + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(Eigen::half) const { + return static_cast(bfloat16_impl::bfloat16_to_float(*this)); + } +}; + +} // end namespace Eigen + +namespace std { +template<> +struct numeric_limits { + static const bool is_specialized = true; + static const bool is_signed = true; + static const bool is_integer = false; + static const bool is_exact = false; + static const bool has_infinity = true; + static const bool has_quiet_NaN = true; + static const bool has_signaling_NaN = true; + static const float_denorm_style has_denorm = numeric_limits::has_denorm; + static const bool has_denorm_loss = numeric_limits::has_denorm_loss; + static const std::float_round_style round_style = numeric_limits::round_style; + static const bool is_iec559 = false; + 
static const bool is_bounded = true; + static const bool is_modulo = false; + static const int digits = 8; + static const int digits10 = 2; + static const int max_digits10 = 4; + static const int radix = 2; + static const int min_exponent = numeric_limits::min_exponent; + static const int min_exponent10 = numeric_limits::min_exponent10; + static const int max_exponent = numeric_limits::max_exponent; + static const int max_exponent10 = numeric_limits::max_exponent10; + static const bool traps = numeric_limits::traps; + static const bool tinyness_before = numeric_limits::tinyness_before; + + static Eigen::bfloat16 (min)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0080); } + static Eigen::bfloat16 lowest() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0xff7f); } + static Eigen::bfloat16 (max)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f7f); } + static Eigen::bfloat16 epsilon() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); } + static Eigen::bfloat16 round_error() { return Eigen::bfloat16(0x3f00); } + static Eigen::bfloat16 infinity() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); } + static Eigen::bfloat16 quiet_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); } + static Eigen::bfloat16 signaling_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f81); } + static Eigen::bfloat16 denorm_min() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0001); } +}; + +// If std::numeric_limits is specialized, should also specialize +// std::numeric_limits, std::numeric_limits, and +// std::numeric_limits +// https://stackoverflow.com/a/16519653/ +template<> +struct numeric_limits : numeric_limits {}; +template<> +struct numeric_limits : numeric_limits {}; +template<> +struct numeric_limits : numeric_limits {}; +} // end namespace std + +namespace Eigen { + +namespace bfloat16_impl { + +// We need to distinguish ‘clang as the CUDA compiler’ from ‘clang as the host compiler, +// invoked by NVCC’ (e.g. on MacOS). The former needs to see both host and device implementation +// of the functions, while the latter can only deal with one of them. +#if !defined(EIGEN_HAS_NATIVE_BF16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for bfloat16 floats + +#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC) +// We need to provide emulated *host-side* BF16 operators for clang. +#pragma push_macro("EIGEN_DEVICE_FUNC") +#undef EIGEN_DEVICE_FUNC +#if defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_NATIVE_BF16) +#define EIGEN_DEVICE_FUNC __host__ +#else // both host and device need emulated ops. +#define EIGEN_DEVICE_FUNC __host__ __device__ +#endif +#endif + +// Definitions for CPUs, mostly working through conversion +// to/from fp32. 
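Since every operator defined below widens to fp32, computes there, and narrows back with round-to-nearest-even, results carry only bfloat16's 8-bit significand. A minimal illustration of the effect (example values chosen here for illustration, not taken from the patch):

   Eigen::bfloat16 x(256.0f);
   Eigen::bfloat16 y(1.0f);
   // The sum is evaluated as 257.0f in float. bfloat16 values are 2 apart at
   // this magnitude, and the tie at 257 rounds to the even neighbour 256, so
   // the increment is absorbed.
   float z = static_cast<float>(x + y);  // z == 256.0f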
+ +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const bfloat16& b) { + return bfloat16(float(a) + float(b)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const int& b) { + return bfloat16(float(a) + static_cast(b)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const int& a, const bfloat16& b) { + return bfloat16(static_cast(a) + float(b)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator * (const bfloat16& a, const bfloat16& b) { + return bfloat16(float(a) * float(b)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a, const bfloat16& b) { + return bfloat16(float(a) - float(b)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, const bfloat16& b) { + return bfloat16(float(a) / float(b)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a) { + bfloat16 result; + result.value = a.value ^ 0x8000; + return result; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator += (bfloat16& a, const bfloat16& b) { + a = bfloat16(float(a) + float(b)); + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator *= (bfloat16& a, const bfloat16& b) { + a = bfloat16(float(a) * float(b)); + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator -= (bfloat16& a, const bfloat16& b) { + a = bfloat16(float(a) - float(b)); + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator /= (bfloat16& a, const bfloat16& b) { + a = bfloat16(float(a) / float(b)); + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a) { + a += bfloat16(1); + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a) { + a -= bfloat16(1); + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a, int) { + bfloat16 original_value = a; + ++a; + return original_value; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a, int) { + bfloat16 original_value = a; + --a; + return original_value; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const bfloat16& a, const bfloat16& b) { + return numext::equal_strict(float(a),float(b)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const bfloat16& a, const bfloat16& b) { + return numext::not_equal_strict(float(a), float(b)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const bfloat16& a, const bfloat16& b) { + return float(a) < float(b); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const bfloat16& a, const bfloat16& b) { + return float(a) <= float(b); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const bfloat16& a, const bfloat16& b) { + return float(a) > float(b); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const bfloat16& a, const bfloat16& b) { + return float(a) >= float(b); +} + +#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC) +#pragma pop_macro("EIGEN_DEVICE_FUNC") +#endif +#endif // Emulate support for bfloat16 floats + +// Division by an index. Do it in full float precision to avoid accuracy +// issues in converting the denominator to bfloat16. 
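As a concrete illustration (values chosen here, not from the patch): an index such as 1000003 is not representable in bfloat16 and would round to 999424 if the denominator were converted first, so the overload below divides in float and rounds only the final quotient:

   Eigen::bfloat16 sum(1.0f);
   Eigen::Index n = 1000003;
   // Forms 1.0f / 1000003 in fp32, then performs a single rounding to
   // bfloat16, instead of dividing by bfloat16(1000003.0f), which is 999424.
   Eigen::bfloat16 mean = sum / n;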
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, Index b) { + return bfloat16(static_cast(a) / static_cast(b)); +} + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const float v) { + __bfloat16_raw output; + if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(v)) { + output.value = 0x7FC0; + return output; + } else if (std::fabs(v) < std::numeric_limits::min EIGEN_NOT_A_MACRO()) { + // Flush denormal to +/- 0. + output.value = std::signbit(v) ? 0x8000 : 0; + return output; + } + const uint16_t* p = reinterpret_cast(&v); +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + output.value = p[0]; +#else + output.value = p[1]; +#endif + return output; +} + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value) { + __bfloat16_raw h; + h.value = value; + return h; +} + +union float32_bits { + unsigned int u; + float f; +}; + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff) { +#if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16)) + // Nothing to do here +#else + unsigned int input; + float32_bits f; + f.f = ff; + input = f.u; + __bfloat16_raw output; + + if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(ff)) { + // If the value is a NaN, squash it to a qNaN with msb of fraction set, + // this makes sure after truncation we don't end up with an inf. + // + // qNaN magic: All exponent bits set + most significant bit of fraction + // set. + output.value = 0x7fc0; + } else if (std::fabs(ff) < std::numeric_limits::min EIGEN_NOT_A_MACRO()) { + // Flush denormal to +/- 0.0 + output.value = std::signbit(ff) ? 0x8000 : 0; + } else { + // Fast rounding algorithm that rounds a half value to nearest even. This + // reduces expected error when we convert a large number of floats. Here + // is how it works: + // + // Definitions: + // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits + // with the following tags: + // + // Sign | Exp (8 bits) | Frac (23 bits) + // S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT + // + // S: Sign bit. + // E: Exponent bits. + // F: First 6 bits of fraction. + // L: Least significant bit of resulting bfloat16 if we truncate away the + // rest of the float32. This is also the 7th bit of fraction + // R: Rounding bit, 8th bit of fraction. + // T: Sticky bits, rest of fraction, 15 bits. + // + // To round half to nearest even, there are 3 cases where we want to round + // down (simply truncate the result of the bits away, which consists of + // rounding bit and sticky bits) and two cases where we want to round up + // (truncate then add one to the result). + // + // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of + // 1s) as the rounding bias, adds the rounding bias to the input, then + // truncates the last 16 bits away. + // + // To understand how it works, we can analyze this algorithm case by case: + // + // 1. L = 0, R = 0: + // Expect: round down, this is less than half value. + // + // Algorithm: + // - Rounding bias: 0x7fff + 0 = 0x7fff + // - Adding rounding bias to input may create any carry, depending on + // whether there is any value set to 1 in T bits. + // - R may be set to 1 if there is a carry. + // - L remains 0. + // - Note that this case also handles Inf and -Inf, where all fraction + // bits, including L, R and Ts are all 0. The output remains Inf after + // this algorithm. + // + // 2. L = 1, R = 0: + // Expect: round down, this is less than half value. 
+ // + // Algorithm: + // - Rounding bias: 0x7fff + 1 = 0x8000 + // - Adding rounding bias to input doesn't change sticky bits but + // adds 1 to rounding bit. + // - L remains 1. + // + // 3. L = 0, R = 1, all of T are 0: + // Expect: round down, this is exactly at half, the result is already + // even (L=0). + // + // Algorithm: + // - Rounding bias: 0x7fff + 0 = 0x7fff + // - Adding rounding bias to input sets all sticky bits to 1, but + // doesn't create a carry. + // - R remains 1. + // - L remains 0. + // + // 4. L = 1, R = 1: + // Expect: round up, this is exactly at half, the result needs to be + // round to the next even number. + // + // Algorithm: + // - Rounding bias: 0x7fff + 1 = 0x8000 + // - Adding rounding bias to input doesn't change sticky bits, but + // creates a carry from rounding bit. + // - The carry sets L to 0, creates another carry bit and propagate + // forward to F bits. + // - If all the F bits are 1, a carry then propagates to the exponent + // bits, which then creates the minimum value with the next exponent + // value. Note that we won't have the case where exponents are all 1, + // since that's either a NaN (handled in the other if condition) or inf + // (handled in case 1). + // + // 5. L = 0, R = 1, any of T is 1: + // Expect: round up, this is greater than half. + // + // Algorithm: + // - Rounding bias: 0x7fff + 0 = 0x7fff + // - Adding rounding bias to input creates a carry from sticky bits, + // sets rounding bit to 0, then create another carry. + // - The second carry sets L to 1. + // + // Examples: + // + // Exact half value that is already even: + // Input: + // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) + // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT + // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000 + // + // This falls into case 3. We truncate the rest of 16 bits and no + // carry is created into F and L: + // + // Output: + // Sign | Exp (8 bit) | Frac (first 7 bit) + // S E E E E E E E E F F F F F F L + // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 + // + // Exact half value, round to next even number: + // Input: + // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) + // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT + // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000 + // + // This falls into case 4. We create a carry from R and T, + // which then propagates into L and F: + // + // Output: + // Sign | Exp (8 bit) | Frac (first 7 bit) + // S E E E E E E E E F F F F F F L + // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 + // + // + // Max denormal value round to min normal value: + // Input: + // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) + // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT + // 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111 + // + // This falls into case 4. We create a carry from R and T, + // propagate into L and F, which then propagates into exponent + // bits: + // + // Output: + // Sign | Exp (8 bit) | Frac (first 7 bit) + // S E E E E E E E E F F F F F F L + // 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 + // + // Max normal value round to Inf: + // Input: + // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) + // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT + // 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111 + // + // This falls into case 4. 
We create a carry from R and T, + // propagate into L and F, which then propagates into exponent + // bits: + // + // Sign | Exp (8 bit) | Frac (first 7 bit) + // S E E E E E E E E F F F F F F L + // 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 + // + // + // Least significant bit of resulting bfloat. + unsigned int lsb = (input >> 16) & 1; + unsigned int rounding_bias = 0x7fff + lsb; + input += rounding_bias; + output.value = static_cast(input >> 16); + } + return output; +#endif +} + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h) { + float result = 0; + unsigned short* q = reinterpret_cast(&result); +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + q[0] = h.value; +#else + q[1] = h.value; +#endif + return result; +} +// --- standard functions --- + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(const bfloat16& a) { + return std::isinf EIGEN_NOT_A_MACRO(float(a)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const bfloat16& a) { + return std::isnan EIGEN_NOT_A_MACRO(float(a)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const bfloat16& a) { + return !(isinf EIGEN_NOT_A_MACRO (a)) && !(isnan EIGEN_NOT_A_MACRO (a)); +} + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 abs(const bfloat16& a) { + bfloat16 result; + result.value = a.value & 0x7FFF; + return result; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp(const bfloat16& a) { + return bfloat16(::expf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 expm1(const bfloat16& a) { + return bfloat16(numext::expm1(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log(const bfloat16& a) { + return bfloat16(::logf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log1p(const bfloat16& a) { + return bfloat16(numext::log1p(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log10(const bfloat16& a) { + return bfloat16(::log10f(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) { + return bfloat16(::sqrtf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) { + return bfloat16(::powf(float(a), float(b))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(const bfloat16& a) { + return bfloat16(::sinf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cos(const bfloat16& a) { + return bfloat16(::cosf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tan(const bfloat16& a) { + return bfloat16(::tanf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asin(const bfloat16& a) { + return bfloat16(::asinf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acos(const bfloat16& a) { + return bfloat16(::acosf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan(const bfloat16& a) { + return bfloat16(::atanf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sinh(const bfloat16& a) { + return bfloat16(::sinhf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cosh(const bfloat16& a) { + return bfloat16(::coshf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tanh(const bfloat16& a) { + return bfloat16(::tanhf(float(a))); +} +#if EIGEN_HAS_CXX11_MATH +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asinh(const bfloat16& a) { + return bfloat16(::asinh(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acosh(const bfloat16& a) { + return bfloat16(::acosh(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atanh(const bfloat16& a) { + return 
bfloat16(::atanh(float(a))); +} +#endif +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 floor(const bfloat16& a) { + return bfloat16(::floorf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 ceil(const bfloat16& a) { + return bfloat16(::ceilf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmod(const bfloat16& a, const bfloat16& b) { + return bfloat16(::fmodf(float(a), float(b))); +} + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (min)(const bfloat16& a, const bfloat16& b) { + const float f1 = static_cast(a); + const float f2 = static_cast(b); + return f2 < f1 ? b : a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (max)(const bfloat16& a, const bfloat16& b) { + const float f1 = static_cast(a); + const float f2 = static_cast(b); + return f1 < f2 ? b : a; +} + +#ifndef EIGEN_NO_IO +EIGEN_ALWAYS_INLINE std::ostream& operator << (std::ostream& os, const bfloat16& v) { + os << static_cast(v); + return os; +} +#endif + +} // end namespace bfloat16_impl + +namespace internal { + +template<> +struct random_default_impl +{ + static inline bfloat16 run(const bfloat16& x, const bfloat16& y) + { + return x + (y-x) * bfloat16(float(std::rand()) / float(RAND_MAX)); + } + static inline bfloat16 run() + { + return run(bfloat16(-1.f), bfloat16(1.f)); + } +}; + +template<> struct is_arithmetic { enum { value = true }; }; + +} // end namespace internal + +template<> struct NumTraits + : GenericNumTraits +{ + enum { + IsSigned = true, + IsInteger = false, + IsComplex = false, + RequireInitialization = false + }; + + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 epsilon() { + return bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); + } + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 dummy_precision() { return Eigen::bfloat16(5e-2f); } + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 highest() { + return bfloat16_impl::raw_uint16_to_bfloat16(0x7F7F); + } + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 lowest() { + return bfloat16_impl::raw_uint16_to_bfloat16(0xFF7F); + } + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 infinity() { + return bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); + } + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 quiet_NaN() { + return bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); + } +}; + +} // end namespace Eigen + +namespace std { + +#if __cplusplus > 199711L +template <> +struct hash { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t operator()(const Eigen::bfloat16& a) const { + return hash()(static_cast(a)); + } +}; +#endif + +} // end namespace std + + +namespace Eigen { +namespace numext { + +template<> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE +bool (isnan)(const Eigen::bfloat16& h) { + return (bfloat16_impl::isnan)(h); +} + +template<> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE +bool (isinf)(const Eigen::bfloat16& h) { + return (bfloat16_impl::isinf)(h); +} + +template<> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE +bool (isfinite)(const Eigen::bfloat16& h) { + return (bfloat16_impl::isfinite)(h); +} + +} // namespace Eigen +} // namespace numext + +#endif // EIGEN_BFLOAT16_H diff --git a/Eigen/src/Core/arch/Default/TypeCasting.h b/Eigen/src/Core/arch/Default/TypeCasting.h index b6df98468..fb8183b78 100644 --- a/Eigen/src/Core/arch/Default/TypeCasting.h +++ b/Eigen/src/Core/arch/Default/TypeCasting.h @@ -71,6 +71,49 @@ template<> struct functor_traits > { enum { Cost = NumTraits::AddCost, PacketAccess = false }; }; + +template<> +struct scalar_cast_op { + 
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op) + typedef Eigen::bfloat16 result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const float& a) const { + return Eigen::bfloat16(a); + } +}; + +template<> +struct functor_traits > +{ enum { Cost = NumTraits::AddCost, PacketAccess = false }; }; + + +template<> +struct scalar_cast_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op) + typedef Eigen::bfloat16 result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const int& a) const { + return Eigen::bfloat16(static_cast(a)); + } +}; + +template<> +struct functor_traits > +{ enum { Cost = NumTraits::AddCost, PacketAccess = false }; }; + + +template<> +struct scalar_cast_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op) + typedef float result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::bfloat16& a) const { + return static_cast(a); + } +}; + +template<> +struct functor_traits > +{ enum { Cost = NumTraits::AddCost, PacketAccess = false }; }; + + } } diff --git a/Eigen/src/Core/util/ConfigureVectorization.h b/Eigen/src/Core/util/ConfigureVectorization.h index 952abc306..739dab60d 100644 --- a/Eigen/src/Core/util/ConfigureVectorization.h +++ b/Eigen/src/Core/util/ConfigureVectorization.h @@ -288,6 +288,9 @@ #ifdef __AVX512ER__ #define EIGEN_VECTORIZE_AVX512ER #endif + #ifdef __AVX512BF16__ + #define EIGEN_VECTORIZE_AVX512BF16 + #endif #endif #endif diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 06c7144ee..b0950d081 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -286,6 +286,7 @@ ei_add_test(ctorleak) ei_add_test(mpl2only) ei_add_test(inplace_decomposition) ei_add_test(half_float) +ei_add_test(bfloat16_float) ei_add_test(array_of_string) ei_add_test(num_dimensions) ei_add_test(stl_iterators) diff --git a/test/bfloat16_float.cpp b/test/bfloat16_float.cpp new file mode 100644 index 000000000..eb55f7d45 --- /dev/null +++ b/test/bfloat16_float.cpp @@ -0,0 +1,399 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include +#include +#include + +#include "main.h" + +#include + +// Make sure it's possible to forward declare Eigen::bfloat16 +namespace Eigen { +struct bfloat16; +} + +using Eigen::bfloat16; + +float BinaryToFloat(uint32_t sign, uint32_t exponent, uint32_t high_mantissa, + uint32_t low_mantissa) { + float dest; + uint32_t src = (sign << 31) + (exponent << 23) + (high_mantissa << 16) + low_mantissa; + memcpy(static_cast(&dest), + static_cast(&src), sizeof(dest)); + return dest; +} + +void test_truncate(float input, float expected_truncation, float expected_rounding){ + bfloat16 truncated = Eigen::bfloat16_impl::truncate_to_bfloat16(input); + bfloat16 rounded = Eigen::bfloat16_impl::float_to_bfloat16_rtne(input); + if ((numext::isnan)(input)){ + VERIFY((numext::isnan)(static_cast(truncated)) || (numext::isinf)(static_cast(truncated))); + VERIFY((numext::isnan)(static_cast(rounded)) || (numext::isinf)(static_cast(rounded))); + return; + } + VERIFY_IS_EQUAL(expected_truncation, static_cast(truncated)); + VERIFY_IS_EQUAL(expected_rounding, static_cast(rounded)); +} + +void test_conversion() +{ + using Eigen::bfloat16_impl::__bfloat16_raw; + + // Conversion from float. 
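The expected raw values checked below can be reproduced with the same rounding-bias trick used by float_to_bfloat16_rtne. A scalar sketch for the 0x3eab case (NaN and denormal handling omitted; assumes <cstdint> and <cstring>):

   float f = 0.33333f;                    // upper 16 bits of its pattern are 0x3eaa,
   uint32_t bits;                         // and the rounding bit below them is set
   std::memcpy(&bits, &f, sizeof(bits));
   uint32_t lsb = (bits >> 16) & 1;       // 0 for this value
   uint16_t bf = static_cast<uint16_t>((bits + 0x7fff + lsb) >> 16);  // 0x3eab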
+ VERIFY_IS_EQUAL(bfloat16(1.0f).value, 0x3f80); + VERIFY_IS_EQUAL(bfloat16(0.5f).value, 0x3f00); + VERIFY_IS_EQUAL(bfloat16(0.33333f).value, 0x3eab); + VERIFY_IS_EQUAL(bfloat16(3.38e38f).value, 0x7f7e); + VERIFY_IS_EQUAL(bfloat16(3.40e38f).value, 0x7f80); // Becomes infinity. + + // Verify round-to-nearest-even behavior. + float val1 = static_cast(bfloat16(__bfloat16_raw(0x3c00))); + float val2 = static_cast(bfloat16(__bfloat16_raw(0x3c01))); + float val3 = static_cast(bfloat16(__bfloat16_raw(0x3c02))); + VERIFY_IS_EQUAL(bfloat16(0.5f * (val1 + val2)).value, 0x3c00); + VERIFY_IS_EQUAL(bfloat16(0.5f * (val2 + val3)).value, 0x3c02); + + // Conversion from int. + VERIFY_IS_EQUAL(bfloat16(-1).value, 0xbf80); + VERIFY_IS_EQUAL(bfloat16(0).value, 0x0000); + VERIFY_IS_EQUAL(bfloat16(1).value, 0x3f80); + VERIFY_IS_EQUAL(bfloat16(2).value, 0x4000); + VERIFY_IS_EQUAL(bfloat16(3).value, 0x4040); + VERIFY_IS_EQUAL(bfloat16(12).value, 0x4140); + + // Conversion from bool. + VERIFY_IS_EQUAL(bfloat16(false).value, 0x0000); + VERIFY_IS_EQUAL(bfloat16(true).value, 0x3f80); + + // Conversion to float. + VERIFY_IS_EQUAL(static_cast(bfloat16(__bfloat16_raw(0x0000))), 0.0f); + VERIFY_IS_EQUAL(static_cast(bfloat16(__bfloat16_raw(0x3f80))), 1.0f); + + // Zero representations + VERIFY_IS_EQUAL(bfloat16(0.0f), bfloat16(0.0f)); + VERIFY_IS_EQUAL(bfloat16(-0.0f), bfloat16(0.0f)); + VERIFY_IS_EQUAL(bfloat16(-0.0f), bfloat16(-0.0f)); + VERIFY_IS_EQUAL(bfloat16(0.0f).value, 0x0000); + VERIFY_IS_EQUAL(bfloat16(-0.0f).value, 0x8000); + + // Flush denormals to zero + for (float denorm = -std::numeric_limits::denorm_min(); + denorm < std::numeric_limits::denorm_min(); + denorm = nextafterf(denorm, 1.0f)) { + bfloat16 bf_trunc = Eigen::bfloat16_impl::truncate_to_bfloat16(denorm); + VERIFY_IS_EQUAL(static_cast(bf_trunc), 0.0f); + if (std::signbit(denorm)) { + VERIFY_IS_EQUAL(bf_trunc.value, 0x8000); + } else { + VERIFY_IS_EQUAL(bf_trunc.value, 0x0000); + } + bfloat16 bf_round = Eigen::bfloat16_impl::float_to_bfloat16_rtne(denorm); + VERIFY_IS_EQUAL(static_cast(bf_round), 0.0f); + if (std::signbit(denorm)) { + VERIFY_IS_EQUAL(bf_round.value, 0x8000); + } else { + VERIFY_IS_EQUAL(bf_round.value, 0x0000); + } + } + + // Default is zero + VERIFY_IS_EQUAL(static_cast(bfloat16()), 0.0f); + + // Representable floats round trip via bfloat16 + VERIFY_IS_EQUAL(static_cast(static_cast(-std::numeric_limits::infinity())), -std::numeric_limits::infinity()); + VERIFY_IS_EQUAL(static_cast(static_cast(std::numeric_limits::infinity())), std::numeric_limits::infinity()); + VERIFY_IS_EQUAL(static_cast(static_cast(-1.0f)), -1.0f); + VERIFY_IS_EQUAL(static_cast(static_cast(-0.5f)), -0.5f); + VERIFY_IS_EQUAL(static_cast(static_cast(-0.0f)), -0.0f); + VERIFY_IS_EQUAL(static_cast(static_cast(1.0f)), 1.0f); + VERIFY_IS_EQUAL(static_cast(static_cast(0.5f)), 0.5f); + VERIFY_IS_EQUAL(static_cast(static_cast(0.0f)), 0.0f); + + // Truncate test + test_truncate( + BinaryToFloat(0, 0x80, 0x48, 0xf5c3), + BinaryToFloat(0, 0x80, 0x48, 0x0000), + BinaryToFloat(0, 0x80, 0x49, 0x0000)); + test_truncate( + BinaryToFloat(1, 0x80, 0x48, 0xf5c3), + BinaryToFloat(1, 0x80, 0x48, 0x0000), + BinaryToFloat(1, 0x80, 0x49, 0x0000)); + test_truncate( + BinaryToFloat(0, 0x80, 0x48, 0x8000), + BinaryToFloat(0, 0x80, 0x48, 0x0000), + BinaryToFloat(0, 0x80, 0x48, 0x0000)); + test_truncate( + BinaryToFloat(0, 0xff, 0x00, 0x0001), + BinaryToFloat(0, 0xff, 0x40, 0x0000), + BinaryToFloat(0, 0xff, 0x40, 0x0000)); + test_truncate( + BinaryToFloat(0, 0xff, 0x7f, 0xffff), + 
+      BinaryToFloat(0, 0xff, 0x40, 0x0000),
+      BinaryToFloat(0, 0xff, 0x40, 0x0000));
+  test_truncate(
+      BinaryToFloat(1, 0x80, 0x48, 0xc000),
+      BinaryToFloat(1, 0x80, 0x48, 0x0000),
+      BinaryToFloat(1, 0x80, 0x49, 0x0000));
+  test_truncate(
+      BinaryToFloat(0, 0x80, 0x48, 0x0000),
+      BinaryToFloat(0, 0x80, 0x48, 0x0000),
+      BinaryToFloat(0, 0x80, 0x48, 0x0000));
+  test_truncate(
+      BinaryToFloat(0, 0x80, 0x48, 0x4000),
+      BinaryToFloat(0, 0x80, 0x48, 0x0000),
+      BinaryToFloat(0, 0x80, 0x48, 0x0000));
+  test_truncate(
+      BinaryToFloat(0, 0x80, 0x48, 0x8000),
+      BinaryToFloat(0, 0x80, 0x48, 0x0000),
+      BinaryToFloat(0, 0x80, 0x48, 0x0000));
+  test_truncate(
+      BinaryToFloat(0, 0x00, 0x48, 0x8000),
+      BinaryToFloat(0, 0x00, 0x00, 0x0000),
+      BinaryToFloat(0, 0x00, 0x00, 0x0000));
+  test_truncate(
+      BinaryToFloat(0, 0x00, 0x7f, 0xc000),
+      BinaryToFloat(0, 0x00, 0x00, 0x0000),
+      BinaryToFloat(0, 0x00, 0x00, 0x0000));
+
+  // Conversion
+  Array<float,1,100> a;
+  for (int i = 0; i < 100; i++) a(i) = i + 1.25;
+  Array<bfloat16,1,100> b = a.cast<bfloat16>();
+  Array<float,1,100> c = b.cast<float>();
+  for (int i = 0; i < 100; ++i) {
+    VERIFY_LE(numext::abs(c(i) - a(i)), a(i) / 128);
+  }
+
+  // Epsilon
+  VERIFY_LE(1.0f, static_cast<float>((std::numeric_limits<bfloat16>::epsilon)() + bfloat16(1.0f)));
+  VERIFY_IS_EQUAL(1.0f, static_cast<float>((std::numeric_limits<bfloat16>::epsilon)() / bfloat16(2.0f) + bfloat16(1.0f)));
+
+  // Negate
+  VERIFY_IS_EQUAL(static_cast<float>(-bfloat16(3.0f)), -3.0f);
+  VERIFY_IS_EQUAL(static_cast<float>(-bfloat16(-4.5f)), 4.5f);
+
+
+#if !EIGEN_COMP_MSVC
+  // Visual Studio errors out on divisions by 0
+  VERIFY((numext::isnan)(static_cast<float>(bfloat16(0.0 / 0.0))));
+  VERIFY((numext::isinf)(static_cast<float>(bfloat16(1.0 / 0.0))));
+  VERIFY((numext::isinf)(static_cast<float>(bfloat16(-1.0 / 0.0))));
+
+  // Visual Studio errors out on divisions by 0
+  VERIFY((numext::isnan)(bfloat16(0.0 / 0.0)));
+  VERIFY((numext::isinf)(bfloat16(1.0 / 0.0)));
+  VERIFY((numext::isinf)(bfloat16(-1.0 / 0.0)));
+#endif
+
+  // NaNs and infinities.
+  VERIFY(!(numext::isinf)(static_cast<float>(bfloat16(3.38e38f))));  // Largest finite number.
+  VERIFY(!(numext::isnan)(static_cast<float>(bfloat16(0.0f))));
+  VERIFY((numext::isinf)(static_cast<float>(bfloat16(__bfloat16_raw(0xff80)))));
+  VERIFY((numext::isnan)(static_cast<float>(bfloat16(__bfloat16_raw(0xffc0)))));
+  VERIFY((numext::isinf)(static_cast<float>(bfloat16(__bfloat16_raw(0x7f80)))));
+  VERIFY((numext::isnan)(static_cast<float>(bfloat16(__bfloat16_raw(0x7fc0)))));
+
+  // Exactly same checks as above, just directly on the bfloat16 representation.
+  VERIFY(!(numext::isinf)(bfloat16(__bfloat16_raw(0x7bff))));
+  VERIFY(!(numext::isnan)(bfloat16(__bfloat16_raw(0x0000))));
+  VERIFY((numext::isinf)(bfloat16(__bfloat16_raw(0xff80))));
+  VERIFY((numext::isnan)(bfloat16(__bfloat16_raw(0xffc0))));
+  VERIFY((numext::isinf)(bfloat16(__bfloat16_raw(0x7f80))));
+  VERIFY((numext::isnan)(bfloat16(__bfloat16_raw(0x7fc0))));
+}
+
+void test_numtraits()
+{
+  std::cout << "epsilon = " << NumTraits<bfloat16>::epsilon() << " (0x" << std::hex << NumTraits<bfloat16>::epsilon().value << ")" << std::endl;
+  std::cout << "highest = " << NumTraits<bfloat16>::highest() << " (0x" << std::hex << NumTraits<bfloat16>::highest().value << ")" << std::endl;
+  std::cout << "lowest = " << NumTraits<bfloat16>::lowest() << " (0x" << std::hex << NumTraits<bfloat16>::lowest().value << ")" << std::endl;
+  std::cout << "min = " << (std::numeric_limits<bfloat16>::min)() << " (0x" << std::hex << (std::numeric_limits<bfloat16>::min)().value << ")" << std::endl;
+  std::cout << "denorm min = " << (std::numeric_limits<bfloat16>::denorm_min)() << " (0x" << std::hex << (std::numeric_limits<bfloat16>::denorm_min)().value << ")" << std::endl;
+  std::cout << "infinity = " << NumTraits<bfloat16>::infinity() << " (0x" << std::hex << NumTraits<bfloat16>::infinity().value << ")" << std::endl;
+  std::cout << "quiet nan = " << NumTraits<bfloat16>::quiet_NaN() << " (0x" << std::hex << NumTraits<bfloat16>::quiet_NaN().value << ")" << std::endl;
+  std::cout << "signaling nan = " << std::numeric_limits<bfloat16>::signaling_NaN() << " (0x" << std::hex << std::numeric_limits<bfloat16>::signaling_NaN().value << ")" << std::endl;
+
+  VERIFY(NumTraits<bfloat16>::IsSigned);
+
+  VERIFY_IS_EQUAL( std::numeric_limits<bfloat16>::infinity().value, bfloat16(std::numeric_limits<float>::infinity()).value );
+  VERIFY_IS_EQUAL( std::numeric_limits<bfloat16>::quiet_NaN().value, bfloat16(std::numeric_limits<float>::quiet_NaN()).value );
+  VERIFY( (std::numeric_limits<bfloat16>::min)() > bfloat16(0.f) );
+  VERIFY( (std::numeric_limits<bfloat16>::denorm_min)() > bfloat16(0.f) );
+  VERIFY_IS_EQUAL( (std::numeric_limits<bfloat16>::denorm_min)()/bfloat16(2), bfloat16(0.f) );
+}
+
+void test_arithmetic()
+{
+  VERIFY_IS_EQUAL(static_cast<float>(bfloat16(2) + bfloat16(2)), 4);
+  VERIFY_IS_EQUAL(static_cast<float>(bfloat16(2) + bfloat16(-2)), 0);
+  VERIFY_IS_APPROX(static_cast<float>(bfloat16(0.33333f) + bfloat16(0.66667f)), 1.0f);
+  VERIFY_IS_EQUAL(static_cast<float>(bfloat16(2.0f) * bfloat16(-5.5f)), -11.0f);
+  VERIFY_IS_APPROX(static_cast<float>(bfloat16(1.0f) / bfloat16(3.0f)), 0.3339f);
+  VERIFY_IS_EQUAL(static_cast<float>(-bfloat16(4096.0f)), -4096.0f);
+  VERIFY_IS_EQUAL(static_cast<float>(-bfloat16(-4096.0f)), 4096.0f);
+}
+
+void test_comparison()
+{
+  VERIFY(bfloat16(1.0f) > bfloat16(0.5f));
+  VERIFY(bfloat16(0.5f) < bfloat16(1.0f));
+  VERIFY(!(bfloat16(1.0f) < bfloat16(0.5f)));
+  VERIFY(!(bfloat16(0.5f) > bfloat16(1.0f)));
+
+  VERIFY(!(bfloat16(4.0f) > bfloat16(4.0f)));
+  VERIFY(!(bfloat16(4.0f) < bfloat16(4.0f)));
+
+  VERIFY(!(bfloat16(0.0f) < bfloat16(-0.0f)));
+  VERIFY(!(bfloat16(-0.0f) < bfloat16(0.0f)));
+  VERIFY(!(bfloat16(0.0f) > bfloat16(-0.0f)));
+  VERIFY(!(bfloat16(-0.0f) > bfloat16(0.0f)));
+
+  VERIFY(bfloat16(0.2f) > bfloat16(-1.0f));
+  VERIFY(bfloat16(-1.0f) < bfloat16(0.2f));
+  VERIFY(bfloat16(-16.0f) < bfloat16(-15.0f));
+
+  VERIFY(bfloat16(1.0f) == bfloat16(1.0f));
+  VERIFY(bfloat16(1.0f) != bfloat16(2.0f));
+
+  // Comparisons with NaNs and infinities.
+#if !EIGEN_COMP_MSVC
+  // Visual Studio errors out on divisions by 0
+  VERIFY(!(bfloat16(0.0 / 0.0) == bfloat16(0.0 / 0.0)));
+  VERIFY(bfloat16(0.0 / 0.0) != bfloat16(0.0 / 0.0));
+
+  VERIFY(!(bfloat16(1.0) == bfloat16(0.0 / 0.0)));
+  VERIFY(!(bfloat16(1.0) < bfloat16(0.0 / 0.0)));
+  VERIFY(!(bfloat16(1.0) > bfloat16(0.0 / 0.0)));
+  VERIFY(bfloat16(1.0) != bfloat16(0.0 / 0.0));
+
+  VERIFY(bfloat16(1.0) < bfloat16(1.0 / 0.0));
+  VERIFY(bfloat16(1.0) > bfloat16(-1.0 / 0.0));
+#endif
+}
+
+void test_basic_functions()
+{
+  VERIFY_IS_EQUAL(static_cast<float>(numext::abs(bfloat16(3.5f))), 3.5f);
+  VERIFY_IS_EQUAL(static_cast<float>(abs(bfloat16(3.5f))), 3.5f);
+  VERIFY_IS_EQUAL(static_cast<float>(numext::abs(bfloat16(-3.5f))), 3.5f);
+  VERIFY_IS_EQUAL(static_cast<float>(abs(bfloat16(-3.5f))), 3.5f);
+
+  VERIFY_IS_EQUAL(static_cast<float>(numext::floor(bfloat16(3.5f))), 3.0f);
+  VERIFY_IS_EQUAL(static_cast<float>(floor(bfloat16(3.5f))), 3.0f);
+  VERIFY_IS_EQUAL(static_cast<float>(numext::floor(bfloat16(-3.5f))), -4.0f);
+  VERIFY_IS_EQUAL(static_cast<float>(floor(bfloat16(-3.5f))), -4.0f);
+
+  VERIFY_IS_EQUAL(static_cast<float>(numext::ceil(bfloat16(3.5f))), 4.0f);
+  VERIFY_IS_EQUAL(static_cast<float>(ceil(bfloat16(3.5f))), 4.0f);
+  VERIFY_IS_EQUAL(static_cast<float>(numext::ceil(bfloat16(-3.5f))), -3.0f);
+  VERIFY_IS_EQUAL(static_cast<float>(ceil(bfloat16(-3.5f))), -3.0f);
+
+  VERIFY_IS_APPROX(static_cast<float>(numext::sqrt(bfloat16(0.0f))), 0.0f);
+  VERIFY_IS_APPROX(static_cast<float>(sqrt(bfloat16(0.0f))), 0.0f);
+  VERIFY_IS_APPROX(static_cast<float>(numext::sqrt(bfloat16(4.0f))), 2.0f);
+  VERIFY_IS_APPROX(static_cast<float>(sqrt(bfloat16(4.0f))), 2.0f);
+
+  VERIFY_IS_APPROX(static_cast<float>(numext::pow(bfloat16(0.0f), bfloat16(1.0f))), 0.0f);
+  VERIFY_IS_APPROX(static_cast<float>(pow(bfloat16(0.0f), bfloat16(1.0f))), 0.0f);
+  VERIFY_IS_APPROX(static_cast<float>(numext::pow(bfloat16(2.0f), bfloat16(2.0f))), 4.0f);
+  VERIFY_IS_APPROX(static_cast<float>(pow(bfloat16(2.0f), bfloat16(2.0f))), 4.0f);
+
+  VERIFY_IS_EQUAL(static_cast<float>(numext::exp(bfloat16(0.0f))), 1.0f);
+  VERIFY_IS_EQUAL(static_cast<float>(exp(bfloat16(0.0f))), 1.0f);
+  VERIFY_IS_APPROX(static_cast<float>(numext::exp(bfloat16(EIGEN_PI))), 20.f + static_cast<float>(EIGEN_PI));
+  VERIFY_IS_APPROX(static_cast<float>(exp(bfloat16(EIGEN_PI))), 20.f + static_cast<float>(EIGEN_PI));
+
+  VERIFY_IS_EQUAL(static_cast<float>(numext::expm1(bfloat16(0.0f))), 0.0f);
+  VERIFY_IS_EQUAL(static_cast<float>(expm1(bfloat16(0.0f))), 0.0f);
+  VERIFY_IS_APPROX(static_cast<float>(numext::expm1(bfloat16(2.0f))), 6.375f);
+  VERIFY_IS_APPROX(static_cast<float>(expm1(bfloat16(2.0f))), 6.375f);
+
+  VERIFY_IS_EQUAL(static_cast<float>(numext::log(bfloat16(1.0f))), 0.0f);
+  VERIFY_IS_EQUAL(static_cast<float>(log(bfloat16(1.0f))), 0.0f);
+  VERIFY_IS_APPROX(static_cast<float>(numext::log(bfloat16(10.0f))), 2.296875f);
+  VERIFY_IS_APPROX(static_cast<float>(log(bfloat16(10.0f))), 2.296875f);
+
+  VERIFY_IS_EQUAL(static_cast<float>(numext::log1p(bfloat16(0.0f))), 0.0f);
+  VERIFY_IS_EQUAL(static_cast<float>(log1p(bfloat16(0.0f))), 0.0f);
+  VERIFY_IS_APPROX(static_cast<float>(numext::log1p(bfloat16(10.0f))), 2.390625f);
+  VERIFY_IS_APPROX(static_cast<float>(log1p(bfloat16(10.0f))), 2.390625f);
+}
+
+void test_trigonometric_functions()
+{
+  VERIFY_IS_APPROX(numext::cos(bfloat16(0.0f)), bfloat16(cosf(0.0f)));
+  VERIFY_IS_APPROX(cos(bfloat16(0.0f)), bfloat16(cosf(0.0f)));
+  VERIFY_IS_APPROX(numext::cos(bfloat16(EIGEN_PI)), bfloat16(cosf(EIGEN_PI)));
+  // VERIFY_IS_APPROX(numext::cos(bfloat16(EIGEN_PI/2)), bfloat16(cosf(EIGEN_PI/2)));
+  // VERIFY_IS_APPROX(numext::cos(bfloat16(3*EIGEN_PI/2)), bfloat16(cosf(3*EIGEN_PI/2)));
+  VERIFY_IS_APPROX(numext::cos(bfloat16(3.5f)), bfloat16(cosf(3.5f)));
+
+  VERIFY_IS_APPROX(numext::sin(bfloat16(0.0f)), bfloat16(sinf(0.0f)));
+  VERIFY_IS_APPROX(sin(bfloat16(0.0f)), bfloat16(sinf(0.0f)));
+  // VERIFY_IS_APPROX(numext::sin(bfloat16(EIGEN_PI)), bfloat16(sinf(EIGEN_PI)));
+  VERIFY_IS_APPROX(numext::sin(bfloat16(EIGEN_PI/2)), bfloat16(sinf(EIGEN_PI/2)));
+  VERIFY_IS_APPROX(numext::sin(bfloat16(3*EIGEN_PI/2)), bfloat16(sinf(3*EIGEN_PI/2)));
+  VERIFY_IS_APPROX(numext::sin(bfloat16(3.5f)), bfloat16(sinf(3.5f)));
+
+  VERIFY_IS_APPROX(numext::tan(bfloat16(0.0f)), bfloat16(tanf(0.0f)));
+  VERIFY_IS_APPROX(tan(bfloat16(0.0f)), bfloat16(tanf(0.0f)));
+  // VERIFY_IS_APPROX(numext::tan(bfloat16(EIGEN_PI)), bfloat16(tanf(EIGEN_PI)));
+  // VERIFY_IS_APPROX(numext::tan(bfloat16(EIGEN_PI/2)), bfloat16(tanf(EIGEN_PI/2)));
+  // VERIFY_IS_APPROX(numext::tan(bfloat16(3*EIGEN_PI/2)), bfloat16(tanf(3*EIGEN_PI/2)));
+  VERIFY_IS_APPROX(numext::tan(bfloat16(3.5f)), bfloat16(tanf(3.5f)));
+}
+
+void test_array()
+{
+  typedef Array<bfloat16,1,Dynamic> ArrayXh;
+  Index size = internal::random<Index>(1,10);
+  Index i = internal::random<Index>(0,size-1);
+  ArrayXh a1 = ArrayXh::Random(size), a2 = ArrayXh::Random(size);
+  VERIFY_IS_APPROX( a1+a1, bfloat16(2)*a1 );
+  VERIFY( (a1.abs() >= bfloat16(0)).all() );
+  VERIFY_IS_APPROX( (a1*a1).sqrt(), a1.abs() );
+
+  VERIFY( ((a1.min)(a2) <= (a1.max)(a2)).all() );
+  a1(i) = bfloat16(-10.);
+  VERIFY_IS_EQUAL( a1.minCoeff(), bfloat16(-10.) );
+  a1(i) = bfloat16(10.);
+  VERIFY_IS_EQUAL( a1.maxCoeff(), bfloat16(10.) );
+
+  std::stringstream ss;
+  ss << a1;
+}
+
+void test_product()
+{
+  typedef Matrix<bfloat16,Dynamic,Dynamic> MatrixXh;
+  Index rows  = internal::random<Index>(1,EIGEN_TEST_MAX_SIZE);
+  Index cols  = internal::random<Index>(1,EIGEN_TEST_MAX_SIZE);
+  Index depth = internal::random<Index>(1,EIGEN_TEST_MAX_SIZE);
+  MatrixXh Ah = MatrixXh::Random(rows,depth);
+  MatrixXh Bh = MatrixXh::Random(depth,cols);
+  MatrixXh Ch = MatrixXh::Random(rows,cols);
+  MatrixXf Af = Ah.cast<float>();
+  MatrixXf Bf = Bh.cast<float>();
+  MatrixXf Cf = Ch.cast<float>();
+  VERIFY_IS_APPROX(Ch.noalias()+=Ah*Bh, (Cf.noalias()+=Af*Bf).cast<bfloat16>());
+}
+
+EIGEN_DECLARE_TEST(bfloat16_float)
+{
+  CALL_SUBTEST(test_numtraits());
+  for(int i = 0; i < g_repeat; i++) {
+    CALL_SUBTEST(test_conversion());
+    CALL_SUBTEST(test_arithmetic());
+    CALL_SUBTEST(test_comparison());
+    CALL_SUBTEST(test_basic_functions());
+    CALL_SUBTEST(test_trigonometric_functions());
+    CALL_SUBTEST(test_array());
+    CALL_SUBTEST(test_product());
+  }
+}
diff --git a/test/main.h b/test/main.h
index 54553f742..19e6f959d 100644
--- a/test/main.h
+++ b/test/main.h
@@ -435,6 +435,7 @@ EIGEN_TEST_SCALAR_TEST_OVERLOAD(unsigned long long)
 EIGEN_TEST_SCALAR_TEST_OVERLOAD(float)
 EIGEN_TEST_SCALAR_TEST_OVERLOAD(double)
 EIGEN_TEST_SCALAR_TEST_OVERLOAD(half)
+EIGEN_TEST_SCALAR_TEST_OVERLOAD(bfloat16)
 
 #undef EIGEN_TEST_SCALAR_TEST_OVERLOAD
 
diff --git a/test/numext.cpp b/test/numext.cpp
index 8c6447d40..ff4d13ff3 100644
--- a/test/numext.cpp
+++ b/test/numext.cpp
@@ -45,6 +45,7 @@ EIGEN_DECLARE_TEST(numext) {
   CALL_SUBTEST( check_abs<long>() );
   CALL_SUBTEST( check_abs<unsigned long>() );
   CALL_SUBTEST( check_abs<half>() );
+  CALL_SUBTEST( check_abs<bfloat16>() );
   CALL_SUBTEST( check_abs<float>() );
   CALL_SUBTEST( check_abs<double>() );
   CALL_SUBTEST( check_abs<long double>() );
diff --git a/test/packetmath.cpp b/test/packetmath.cpp
index 64ba28741..a82b2b87a 100644
--- a/test/packetmath.cpp
+++ b/test/packetmath.cpp
@@ -836,6 +836,7 @@ EIGEN_DECLARE_TEST(packetmath)
 #ifdef EIGEN_PACKET_MATH_SSE_H
     CALL_SUBTEST_14(( packetmath<bool, internal::packet_traits<bool>::type>() ));
 #endif
+    CALL_SUBTEST_15(( packetmath<bfloat16, internal::packet_traits<bfloat16>::type>() ));
     g_first_pass = false;
   }
 }
diff --git a/test/packetmath_test_shared.h b/test/packetmath_test_shared.h
index 04f719f96..5be10997a 100644
--- a/test/packetmath_test_shared.h
+++ b/test/packetmath_test_shared.h
@@ -50,6 +50,7 @@ T apply_bit_op(Bits a, Bits b, Func f) {
   EIGEN_TEST_MAKE_BITWISE2(OP,FUNC,float)    \
   EIGEN_TEST_MAKE_BITWISE2(OP,FUNC,double)   \
   EIGEN_TEST_MAKE_BITWISE2(OP,FUNC,half)     \
+  EIGEN_TEST_MAKE_BITWISE2(OP,FUNC,bfloat16) \
   EIGEN_TEST_MAKE_BITWISE2(OP,FUNC,std::complex<float>)  \
   EIGEN_TEST_MAKE_BITWISE2(OP,FUNC,std::complex<double>)
 
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h b/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h
index 445248163..ea286fee1 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h
@@ -101,6 +101,17 @@ Eigen::half RandomToTypeUniform<Eigen::half>(uint64_t* state, uint64_t stream) {
   return result - Eigen::half(1.0f);
 }
 
+template <> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+Eigen::bfloat16 RandomToTypeUniform<Eigen::bfloat16>(uint64_t* state, uint64_t stream) {
+  Eigen::bfloat16 result;
+  // Generate 7 random bits for the mantissa
+  unsigned rnd = PCG_XSH_RS_generator(state, stream);
+  result.value = static_cast<uint16_t>(rnd & 0x7fu);
+  // Set the exponent
+  result.value |= (static_cast<uint16_t>(127) << 7);
+  // Return the final result
+  return result - Eigen::bfloat16(1.0f);
+}
 
 template <> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
 float RandomToTypeUniform<float>(uint64_t* state, uint64_t stream) {
diff --git a/unsupported/Eigen/SpecialFunctions b/unsupported/Eigen/SpecialFunctions
index a098ce871..dda6618de 100644
--- a/unsupported/Eigen/SpecialFunctions
+++ b/unsupported/Eigen/SpecialFunctions
@@ -62,6 +62,7 @@ namespace Eigen {
 
 #include "src/SpecialFunctions/BesselFunctionsImpl.h"
 #include "src/SpecialFunctions/BesselFunctionsPacketMath.h"
+#include "src/SpecialFunctions/BesselFunctionsBFloat16.h"
 #include "src/SpecialFunctions/BesselFunctionsHalf.h"
 #include "src/SpecialFunctions/BesselFunctionsFunctors.h"
 #include "src/SpecialFunctions/BesselFunctionsArrayAPI.h"
@@ -70,6 +71,7 @@ namespace Eigen {
 #include "src/SpecialFunctions/HipVectorCompatibility.h"
 #endif
 #include "src/SpecialFunctions/SpecialFunctionsPacketMath.h"
+#include "src/SpecialFunctions/SpecialFunctionsBFloat16.h"
 #include "src/SpecialFunctions/SpecialFunctionsHalf.h"
 #include "src/SpecialFunctions/SpecialFunctionsFunctors.h"
 #include "src/SpecialFunctions/SpecialFunctionsArrayAPI.h"
diff --git a/unsupported/Eigen/src/SpecialFunctions/BesselFunctionsBFloat16.h b/unsupported/Eigen/src/SpecialFunctions/BesselFunctionsBFloat16.h
new file mode 100644
index 000000000..6049cc2fe
--- /dev/null
+++ b/unsupported/Eigen/src/SpecialFunctions/BesselFunctionsBFloat16.h
@@ -0,0 +1,68 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_BESSELFUNCTIONS_BFLOAT16_H
+#define EIGEN_BESSELFUNCTIONS_BFLOAT16_H
+
+namespace Eigen {
+namespace numext {
+
+#if EIGEN_HAS_C99_MATH
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_i0(const Eigen::bfloat16& x) {
+  return Eigen::bfloat16(Eigen::numext::bessel_i0(static_cast<float>(x)));
+}
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_i0e(const Eigen::bfloat16& x) {
+  return Eigen::bfloat16(Eigen::numext::bessel_i0e(static_cast<float>(x)));
+}
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_i1(const Eigen::bfloat16& x) {
+  return Eigen::bfloat16(Eigen::numext::bessel_i1(static_cast<float>(x)));
+}
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_i1e(const Eigen::bfloat16& x) {
+  return Eigen::bfloat16(Eigen::numext::bessel_i1e(static_cast<float>(x)));
+}
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_j0(const Eigen::bfloat16& x) {
+  return Eigen::bfloat16(Eigen::numext::bessel_j0(static_cast<float>(x)));
+}
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_j1(const Eigen::bfloat16& x) {
+  return Eigen::bfloat16(Eigen::numext::bessel_j1(static_cast<float>(x)));
+}
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_y0(const Eigen::bfloat16& x) {
+  return Eigen::bfloat16(Eigen::numext::bessel_y0(static_cast<float>(x)));
+}
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_y1(const Eigen::bfloat16& x) {
+  return Eigen::bfloat16(Eigen::numext::bessel_y1(static_cast<float>(x)));
+}
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_k0(const Eigen::bfloat16& x) {
+  return Eigen::bfloat16(Eigen::numext::bessel_k0(static_cast<float>(x)));
+}
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_k0e(const Eigen::bfloat16& x) {
+  return Eigen::bfloat16(Eigen::numext::bessel_k0e(static_cast<float>(x)));
+}
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_k1(const Eigen::bfloat16& x) {
+  return Eigen::bfloat16(Eigen::numext::bessel_k1(static_cast<float>(x)));
+}
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_k1e(const Eigen::bfloat16& x) {
+  return Eigen::bfloat16(Eigen::numext::bessel_k1e(static_cast<float>(x)));
+}
+#endif
+
+}  // end namespace numext
+}  // end namespace Eigen
+
+#endif  // EIGEN_BESSELFUNCTIONS_BFLOAT16_H
diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsBFloat16.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsBFloat16.h
new file mode 100644
index 000000000..2d94231f0
--- /dev/null
+++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsBFloat16.h
@@ -0,0 +1,58 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_SPECIALFUNCTIONS_BFLOAT16_H
+#define EIGEN_SPECIALFUNCTIONS_BFLOAT16_H
+
+namespace Eigen {
+namespace numext {
+
+#if EIGEN_HAS_C99_MATH
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 lgamma(const Eigen::bfloat16& a) {
+  return Eigen::bfloat16(Eigen::numext::lgamma(static_cast<float>(a)));
+}
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 digamma(const Eigen::bfloat16& a) {
+  return Eigen::bfloat16(Eigen::numext::digamma(static_cast<float>(a)));
+}
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 zeta(const Eigen::bfloat16& x, const Eigen::bfloat16& q) {
+  return Eigen::bfloat16(Eigen::numext::zeta(static_cast<float>(x), static_cast<float>(q)));
+}
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 polygamma(const Eigen::bfloat16& n, const Eigen::bfloat16& x) {
+  return Eigen::bfloat16(Eigen::numext::polygamma(static_cast<float>(n), static_cast<float>(x)));
+}
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 erf(const Eigen::bfloat16& a) {
+  return Eigen::bfloat16(Eigen::numext::erf(static_cast<float>(a)));
+}
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 erfc(const Eigen::bfloat16& a) {
+  return Eigen::bfloat16(Eigen::numext::erfc(static_cast<float>(a)));
+}
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 ndtri(const Eigen::bfloat16& a) {
+  return Eigen::bfloat16(Eigen::numext::ndtri(static_cast<float>(a)));
+}
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 igamma(const Eigen::bfloat16& a, const Eigen::bfloat16& x) {
+  return Eigen::bfloat16(Eigen::numext::igamma(static_cast<float>(a), static_cast<float>(x)));
+}
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 igamma_der_a(const Eigen::bfloat16& a, const Eigen::bfloat16& x) {
+  return Eigen::bfloat16(Eigen::numext::igamma_der_a(static_cast<float>(a), static_cast<float>(x)));
+}
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 gamma_sample_der_alpha(const Eigen::bfloat16& alpha, const Eigen::bfloat16& sample) {
+  return Eigen::bfloat16(Eigen::numext::gamma_sample_der_alpha(static_cast<float>(alpha), static_cast<float>(sample)));
+}
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 igammac(const Eigen::bfloat16& a, const Eigen::bfloat16& x) {
+  return Eigen::bfloat16(Eigen::numext::igammac(static_cast<float>(a), static_cast<float>(x)));
+}
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 betainc(const Eigen::bfloat16& a, const Eigen::bfloat16& b, const Eigen::bfloat16& x) {
+  return Eigen::bfloat16(Eigen::numext::betainc(static_cast<float>(a), static_cast<float>(b), static_cast<float>(x)));
+}
+#endif
+
+}  // end namespace numext
+}  // end namespace Eigen
+
+#endif  // EIGEN_SPECIALFUNCTIONS_BFLOAT16_H
diff --git a/unsupported/test/cxx11_tensor_reduction.cpp b/unsupported/test/cxx11_tensor_reduction.cpp
index 996dba806..f1ac83b1b 100644
--- a/unsupported/test/cxx11_tensor_reduction.cpp
+++ b/unsupported/test/cxx11_tensor_reduction.cpp
@@ -511,6 +511,7 @@ EIGEN_DECLARE_TEST(cxx11_tensor_reduction) {
   CALL_SUBTEST(( test_simple_reductions<float, ColMajor>() ));
   CALL_SUBTEST(( test_simple_reductions<float, RowMajor>() ));
   CALL_SUBTEST(( test_simple_reductions<Eigen::half, ColMajor>() ));
+  CALL_SUBTEST(( test_simple_reductions<Eigen::bfloat16, ColMajor>() ));
   CALL_SUBTEST(test_reductions_in_expr<ColMajor>());
   CALL_SUBTEST(test_reductions_in_expr<RowMajor>());
   CALL_SUBTEST(test_full_reductions<ColMajor>());
-- 
cgit v1.2.3
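
Two small standalone sketches follow. They are not part of the patch above, only illustrations of the bit-level behaviour it relies on. The first mirrors what truncate_to_bfloat16 and float_to_bfloat16_rtne (exercised by test_truncate in test/bfloat16_float.cpp) do, assuming the usual layout in which a bfloat16 is simply the high 16 bits of an IEEE-754 binary32 value; the helper names bf16_truncate and bf16_rtne are made up for the example and NaN handling is omitted for brevity.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Truncation: keep only the top 16 bits of the binary32 representation.
static std::uint16_t bf16_truncate(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<std::uint16_t>(bits >> 16);  // drop the 16 low mantissa bits
}

// Round to nearest, ties to even: bias by 0x7FFF plus the lowest kept bit,
// then truncate; a carry out of the dropped bits rounds the result up.
static std::uint16_t bf16_rtne(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  std::uint32_t lsb = (bits >> 16) & 1u;  // lowest bit that survives truncation
  bits += 0x7FFFu + lsb;
  return static_cast<std::uint16_t>(bits >> 16);
}

int main() {
  float x = 1.005859375f;  // 0x3F80C000: dropped bits are 0xC000, so RTNE rounds up
  std::printf("truncate: 0x%04x  rtne: 0x%04x\n", bf16_truncate(x), bf16_rtne(x));
  // prints "truncate: 0x3f80  rtne: 0x3f81"
  return 0;
}

The second sketch shows why the RandomToTypeUniform<Eigen::bfloat16> specialization added to TensorRandom.h yields a uniform sample in [0, 1): packing an exponent of 127 with 7 random mantissa bits selects one of the 128 evenly spaced bfloat16 values in [1, 2), and subtracting 1 maps that onto [0, 1). The fixed rnd value and the bf16_bits_to_float helper below are stand-ins for the real PCG_XSH_RS_generator draw and for the bfloat16-to-float conversion.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Widen a raw bfloat16 bit pattern to binary32 by shifting into the high half.
static float bf16_bits_to_float(std::uint16_t h) {
  std::uint32_t bits = static_cast<std::uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  std::uint32_t rnd = 0x5Au;  // stand-in for one random draw
  std::uint16_t v = static_cast<std::uint16_t>(rnd & 0x7Fu);  // 7 mantissa bits
  v |= static_cast<std::uint16_t>(127) << 7;                  // exponent of 1.0, sign 0
  float in_1_2 = bf16_bits_to_float(v);                       // uniform over [1, 2)
  std::printf("in [1,2): %f   in [0,1): %f\n", in_1_2, in_1_2 - 1.0f);
  return 0;
}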