aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/Default/BFloat16.h
diff options
context:
space:
mode:
Diffstat (limited to 'Eigen/src/Core/arch/Default/BFloat16.h')
-rw-r--r--Eigen/src/Core/arch/Default/BFloat16.h703
1 files changed, 703 insertions, 0 deletions
diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h
new file mode 100644
index 000000000..c3725d473
--- /dev/null
+++ b/Eigen/src/Core/arch/Default/BFloat16.h
@@ -0,0 +1,703 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+
+#ifndef EIGEN_BFLOAT16_H
+#define EIGEN_BFLOAT16_H
+
+#if __cplusplus > 199711L
+#define EIGEN_EXPLICIT_CAST(tgt_type) explicit operator tgt_type()
+#else
+#define EIGEN_EXPLICIT_CAST(tgt_type) operator tgt_type()
+#endif
+
+namespace Eigen {
+
+struct bfloat16;
+
+namespace bfloat16_impl {
+
+// Make our own __bfloat16_raw definition.
+struct __bfloat16_raw {
+ EIGEN_DEVICE_FUNC __bfloat16_raw() : value(0) {}
+ explicit EIGEN_DEVICE_FUNC __bfloat16_raw(unsigned short raw) : value(raw) {}
+ unsigned short value;
+};
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value);
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff);
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h);
+
+struct bfloat16_base : public __bfloat16_raw {
+ EIGEN_DEVICE_FUNC bfloat16_base() {}
+ EIGEN_DEVICE_FUNC bfloat16_base(const __bfloat16_raw& h) : __bfloat16_raw(h) {}
+};
+
+} // namespace bfloat16_impl
+
+// Class definition.
+struct bfloat16 : public bfloat16_impl::bfloat16_base {
+
+ typedef bfloat16_impl::__bfloat16_raw __bfloat16_raw;
+
+ EIGEN_DEVICE_FUNC bfloat16() {}
+
+ EIGEN_DEVICE_FUNC bfloat16(const __bfloat16_raw& h) : bfloat16_impl::bfloat16_base(h) {}
+
+ explicit EIGEN_DEVICE_FUNC bfloat16(bool b)
+ : bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {}
+ template<class T>
+ explicit EIGEN_DEVICE_FUNC bfloat16(const T& val)
+ : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(static_cast<float>(val))) {}
+ explicit EIGEN_DEVICE_FUNC bfloat16(float f)
+ : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(f)) {}
+ // Following the convention of numpy, converting between complex and
+ // float will lead to loss of imag value.
+ // Single precision complex.
+ typedef std::complex<float> complex64;
+ // Double precision complex.
+ typedef std::complex<double> complex128;
+ explicit EIGEN_DEVICE_FUNC bfloat16(const complex64& val)
+ : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(val.real())) {}
+ explicit EIGEN_DEVICE_FUNC bfloat16(const complex128& val)
+ : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(static_cast<float>(val.real()))) {}
+
+ EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(bool) const {
+ // +0.0 and -0.0 become false, everything else becomes true.
+ return (value & 0x7fff) != 0;
+ }
+ EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(signed char) const {
+ return static_cast<signed char>(bfloat16_impl::bfloat16_to_float(*this));
+ }
+ EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned char) const {
+ return static_cast<unsigned char>(bfloat16_impl::bfloat16_to_float(*this));
+ }
+ EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(short) const {
+ return static_cast<short>(bfloat16_impl::bfloat16_to_float(*this));
+ }
+ EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned short) const {
+ return static_cast<unsigned short>(bfloat16_impl::bfloat16_to_float(*this));
+ }
+ EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(int) const {
+ return static_cast<int>(bfloat16_impl::bfloat16_to_float(*this));
+ }
+ EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned int) const {
+ return static_cast<unsigned int>(bfloat16_impl::bfloat16_to_float(*this));
+ }
+ EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(long) const {
+ return static_cast<long>(bfloat16_impl::bfloat16_to_float(*this));
+ }
+ EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned long) const {
+ return static_cast<unsigned long>(bfloat16_impl::bfloat16_to_float(*this));
+ }
+ EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(long long) const {
+ return static_cast<long long>(bfloat16_impl::bfloat16_to_float(*this));
+ }
+ EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned long long) const {
+ return static_cast<unsigned long long>(bfloat16_to_float(*this));
+ }
+ EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const {
+ return bfloat16_impl::bfloat16_to_float(*this);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(double) const {
+ return static_cast<double>(bfloat16_impl::bfloat16_to_float(*this));
+ }
+ EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(complex64) const {
+ return complex64(bfloat16_impl::bfloat16_to_float(*this), float(0.0));
+ }
+ EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(complex128) const {
+ return complex128(static_cast<double>(bfloat16_impl::bfloat16_to_float(*this)), double(0.0));
+ }
+ EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(Eigen::half) const {
+ return static_cast<Eigen::half>(bfloat16_impl::bfloat16_to_float(*this));
+ }
+};
+
+} // end namespace Eigen
+
+namespace std {
+template<>
+struct numeric_limits<Eigen::bfloat16> {
+ static const bool is_specialized = true;
+ static const bool is_signed = true;
+ static const bool is_integer = false;
+ static const bool is_exact = false;
+ static const bool has_infinity = true;
+ static const bool has_quiet_NaN = true;
+ static const bool has_signaling_NaN = true;
+ static const float_denorm_style has_denorm = numeric_limits<float>::has_denorm;
+ static const bool has_denorm_loss = numeric_limits<float>::has_denorm_loss;
+ static const std::float_round_style round_style = numeric_limits<float>::round_style;
+ static const bool is_iec559 = false;
+ static const bool is_bounded = true;
+ static const bool is_modulo = false;
+ static const int digits = 8;
+ static const int digits10 = 2;
+ static const int max_digits10 = 4;
+ static const int radix = 2;
+ static const int min_exponent = numeric_limits<float>::min_exponent;
+ static const int min_exponent10 = numeric_limits<float>::min_exponent10;
+ static const int max_exponent = numeric_limits<float>::max_exponent;
+ static const int max_exponent10 = numeric_limits<float>::max_exponent10;
+ static const bool traps = numeric_limits<float>::traps;
+ static const bool tinyness_before = numeric_limits<float>::tinyness_before;
+
+ static Eigen::bfloat16 (min)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0080); }
+ static Eigen::bfloat16 lowest() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0xff7f); }
+ static Eigen::bfloat16 (max)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f7f); }
+ static Eigen::bfloat16 epsilon() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); }
+ static Eigen::bfloat16 round_error() { return Eigen::bfloat16(0x3f00); }
+ static Eigen::bfloat16 infinity() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); }
+ static Eigen::bfloat16 quiet_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); }
+ static Eigen::bfloat16 signaling_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f81); }
+ static Eigen::bfloat16 denorm_min() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0001); }
+};
+
+// If std::numeric_limits<T> is specialized, should also specialize
+// std::numeric_limits<const T>, std::numeric_limits<volatile T>, and
+// std::numeric_limits<const volatile T>
+// https://stackoverflow.com/a/16519653/
+template<>
+struct numeric_limits<const Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
+template<>
+struct numeric_limits<volatile Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
+template<>
+struct numeric_limits<const volatile Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
+} // end namespace std
+
+namespace Eigen {
+
+namespace bfloat16_impl {
+
+// We need to distinguish ‘clang as the CUDA compiler’ from ‘clang as the host compiler,
+// invoked by NVCC’ (e.g. on MacOS). The former needs to see both host and device implementation
+// of the functions, while the latter can only deal with one of them.
+#if !defined(EIGEN_HAS_NATIVE_BF16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for bfloat16 floats
+
+#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
+// We need to provide emulated *host-side* BF16 operators for clang.
+#pragma push_macro("EIGEN_DEVICE_FUNC")
+#undef EIGEN_DEVICE_FUNC
+#if defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_NATIVE_BF16)
+#define EIGEN_DEVICE_FUNC __host__
+#else // both host and device need emulated ops.
+#define EIGEN_DEVICE_FUNC __host__ __device__
+#endif
+#endif
+
+// Definitions for CPUs, mostly working through conversion
+// to/from fp32.
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const bfloat16& b) {
+ return bfloat16(float(a) + float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const int& b) {
+ return bfloat16(float(a) + static_cast<float>(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const int& a, const bfloat16& b) {
+ return bfloat16(static_cast<float>(a) + float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator * (const bfloat16& a, const bfloat16& b) {
+ return bfloat16(float(a) * float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a, const bfloat16& b) {
+ return bfloat16(float(a) - float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, const bfloat16& b) {
+ return bfloat16(float(a) / float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a) {
+ bfloat16 result;
+ result.value = a.value ^ 0x8000;
+ return result;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator += (bfloat16& a, const bfloat16& b) {
+ a = bfloat16(float(a) + float(b));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator *= (bfloat16& a, const bfloat16& b) {
+ a = bfloat16(float(a) * float(b));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator -= (bfloat16& a, const bfloat16& b) {
+ a = bfloat16(float(a) - float(b));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator /= (bfloat16& a, const bfloat16& b) {
+ a = bfloat16(float(a) / float(b));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a) {
+ a += bfloat16(1);
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a) {
+ a -= bfloat16(1);
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a, int) {
+ bfloat16 original_value = a;
+ ++a;
+ return original_value;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a, int) {
+ bfloat16 original_value = a;
+ --a;
+ return original_value;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const bfloat16& a, const bfloat16& b) {
+ return numext::equal_strict(float(a),float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const bfloat16& a, const bfloat16& b) {
+ return numext::not_equal_strict(float(a), float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const bfloat16& a, const bfloat16& b) {
+ return float(a) < float(b);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const bfloat16& a, const bfloat16& b) {
+ return float(a) <= float(b);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const bfloat16& a, const bfloat16& b) {
+ return float(a) > float(b);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const bfloat16& a, const bfloat16& b) {
+ return float(a) >= float(b);
+}
+
+#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
+#pragma pop_macro("EIGEN_DEVICE_FUNC")
+#endif
+#endif // Emulate support for bfloat16 floats
+
+// Division by an index. Do it in full float precision to avoid accuracy
+// issues in converting the denominator to bfloat16.
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, Index b) {
+ return bfloat16(static_cast<float>(a) / static_cast<float>(b));
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const float v) {
+ __bfloat16_raw output;
+ if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(v)) {
+ output.value = 0x7FC0;
+ return output;
+ } else if (std::fabs(v) < std::numeric_limits<float>::min EIGEN_NOT_A_MACRO()) {
+ // Flush denormal to +/- 0.
+ output.value = std::signbit(v) ? 0x8000 : 0;
+ return output;
+ }
+ const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
+#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+ output.value = p[0];
+#else
+ output.value = p[1];
+#endif
+ return output;
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value) {
+ __bfloat16_raw h;
+ h.value = value;
+ return h;
+}
+
+union float32_bits {
+ unsigned int u;
+ float f;
+};
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff) {
+#if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
+ // Nothing to do here
+#else
+ unsigned int input;
+ float32_bits f;
+ f.f = ff;
+ input = f.u;
+ __bfloat16_raw output;
+
+ if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(ff)) {
+ // If the value is a NaN, squash it to a qNaN with msb of fraction set,
+ // this makes sure after truncation we don't end up with an inf.
+ //
+ // qNaN magic: All exponent bits set + most significant bit of fraction
+ // set.
+ output.value = 0x7fc0;
+ } else if (std::fabs(ff) < std::numeric_limits<float>::min EIGEN_NOT_A_MACRO()) {
+ // Flush denormal to +/- 0.0
+ output.value = std::signbit(ff) ? 0x8000 : 0;
+ } else {
+ // Fast rounding algorithm that rounds a half value to nearest even. This
+ // reduces expected error when we convert a large number of floats. Here
+ // is how it works:
+ //
+ // Definitions:
+ // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
+ // with the following tags:
+ //
+ // Sign | Exp (8 bits) | Frac (23 bits)
+ // S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT
+ //
+ // S: Sign bit.
+ // E: Exponent bits.
+ // F: First 6 bits of fraction.
+ // L: Least significant bit of resulting bfloat16 if we truncate away the
+ // rest of the float32. This is also the 7th bit of fraction
+ // R: Rounding bit, 8th bit of fraction.
+ // T: Sticky bits, rest of fraction, 15 bits.
+ //
+ // To round half to nearest even, there are 3 cases where we want to round
+ // down (simply truncate the result of the bits away, which consists of
+ // rounding bit and sticky bits) and two cases where we want to round up
+ // (truncate then add one to the result).
+ //
+ // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
+ // 1s) as the rounding bias, adds the rounding bias to the input, then
+ // truncates the last 16 bits away.
+ //
+ // To understand how it works, we can analyze this algorithm case by case:
+ //
+ // 1. L = 0, R = 0:
+ // Expect: round down, this is less than half value.
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 0 = 0x7fff
+ // - Adding rounding bias to input may create any carry, depending on
+ // whether there is any value set to 1 in T bits.
+ // - R may be set to 1 if there is a carry.
+ // - L remains 0.
+ // - Note that this case also handles Inf and -Inf, where all fraction
+ // bits, including L, R and Ts are all 0. The output remains Inf after
+ // this algorithm.
+ //
+ // 2. L = 1, R = 0:
+ // Expect: round down, this is less than half value.
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 1 = 0x8000
+ // - Adding rounding bias to input doesn't change sticky bits but
+ // adds 1 to rounding bit.
+ // - L remains 1.
+ //
+ // 3. L = 0, R = 1, all of T are 0:
+ // Expect: round down, this is exactly at half, the result is already
+ // even (L=0).
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 0 = 0x7fff
+ // - Adding rounding bias to input sets all sticky bits to 1, but
+ // doesn't create a carry.
+ // - R remains 1.
+ // - L remains 0.
+ //
+ // 4. L = 1, R = 1:
+ // Expect: round up, this is exactly at half, the result needs to be
+ // round to the next even number.
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 1 = 0x8000
+ // - Adding rounding bias to input doesn't change sticky bits, but
+ // creates a carry from rounding bit.
+ // - The carry sets L to 0, creates another carry bit and propagate
+ // forward to F bits.
+ // - If all the F bits are 1, a carry then propagates to the exponent
+ // bits, which then creates the minimum value with the next exponent
+ // value. Note that we won't have the case where exponents are all 1,
+ // since that's either a NaN (handled in the other if condition) or inf
+ // (handled in case 1).
+ //
+ // 5. L = 0, R = 1, any of T is 1:
+ // Expect: round up, this is greater than half.
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 0 = 0x7fff
+ // - Adding rounding bias to input creates a carry from sticky bits,
+ // sets rounding bit to 0, then create another carry.
+ // - The second carry sets L to 1.
+ //
+ // Examples:
+ //
+ // Exact half value that is already even:
+ // Input:
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000
+ //
+ // This falls into case 3. We truncate the rest of 16 bits and no
+ // carry is created into F and L:
+ //
+ // Output:
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
+ // S E E E E E E E E F F F F F F L
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
+ //
+ // Exact half value, round to next even number:
+ // Input:
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000
+ //
+ // This falls into case 4. We create a carry from R and T,
+ // which then propagates into L and F:
+ //
+ // Output:
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
+ // S E E E E E E E E F F F F F F L
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
+ //
+ //
+ // Max denormal value round to min normal value:
+ // Input:
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
+ // 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111
+ //
+ // This falls into case 4. We create a carry from R and T,
+ // propagate into L and F, which then propagates into exponent
+ // bits:
+ //
+ // Output:
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
+ // S E E E E E E E E F F F F F F L
+ // 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
+ //
+ // Max normal value round to Inf:
+ // Input:
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
+ // 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111
+ //
+ // This falls into case 4. We create a carry from R and T,
+ // propagate into L and F, which then propagates into exponent
+ // bits:
+ //
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
+ // S E E E E E E E E F F F F F F L
+ // 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
+ //
+ //
+ // Least significant bit of resulting bfloat.
+ unsigned int lsb = (input >> 16) & 1;
+ unsigned int rounding_bias = 0x7fff + lsb;
+ input += rounding_bias;
+ output.value = static_cast<unsigned short>(input >> 16);
+ }
+ return output;
+#endif
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h) {
+ float result = 0;
+ unsigned short* q = reinterpret_cast<unsigned short*>(&result);
+#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+ q[0] = h.value;
+#else
+ q[1] = h.value;
+#endif
+ return result;
+}
+// --- standard functions ---
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(const bfloat16& a) {
+ return std::isinf EIGEN_NOT_A_MACRO(float(a));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const bfloat16& a) {
+ return std::isnan EIGEN_NOT_A_MACRO(float(a));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const bfloat16& a) {
+ return !(isinf EIGEN_NOT_A_MACRO (a)) && !(isnan EIGEN_NOT_A_MACRO (a));
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 abs(const bfloat16& a) {
+ bfloat16 result;
+ result.value = a.value & 0x7FFF;
+ return result;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp(const bfloat16& a) {
+ return bfloat16(::expf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 expm1(const bfloat16& a) {
+ return bfloat16(numext::expm1(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log(const bfloat16& a) {
+ return bfloat16(::logf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log1p(const bfloat16& a) {
+ return bfloat16(numext::log1p(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log10(const bfloat16& a) {
+ return bfloat16(::log10f(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) {
+ return bfloat16(::sqrtf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) {
+ return bfloat16(::powf(float(a), float(b)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(const bfloat16& a) {
+ return bfloat16(::sinf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cos(const bfloat16& a) {
+ return bfloat16(::cosf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tan(const bfloat16& a) {
+ return bfloat16(::tanf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asin(const bfloat16& a) {
+ return bfloat16(::asinf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acos(const bfloat16& a) {
+ return bfloat16(::acosf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan(const bfloat16& a) {
+ return bfloat16(::atanf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sinh(const bfloat16& a) {
+ return bfloat16(::sinhf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cosh(const bfloat16& a) {
+ return bfloat16(::coshf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tanh(const bfloat16& a) {
+ return bfloat16(::tanhf(float(a)));
+}
+#if EIGEN_HAS_CXX11_MATH
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asinh(const bfloat16& a) {
+ return bfloat16(::asinh(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acosh(const bfloat16& a) {
+ return bfloat16(::acosh(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atanh(const bfloat16& a) {
+ return bfloat16(::atanh(float(a)));
+}
+#endif
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 floor(const bfloat16& a) {
+ return bfloat16(::floorf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 ceil(const bfloat16& a) {
+ return bfloat16(::ceilf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmod(const bfloat16& a, const bfloat16& b) {
+ return bfloat16(::fmodf(float(a), float(b)));
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (min)(const bfloat16& a, const bfloat16& b) {
+ const float f1 = static_cast<float>(a);
+ const float f2 = static_cast<float>(b);
+ return f2 < f1 ? b : a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (max)(const bfloat16& a, const bfloat16& b) {
+ const float f1 = static_cast<float>(a);
+ const float f2 = static_cast<float>(b);
+ return f1 < f2 ? b : a;
+}
+
+#ifndef EIGEN_NO_IO
+EIGEN_ALWAYS_INLINE std::ostream& operator << (std::ostream& os, const bfloat16& v) {
+ os << static_cast<float>(v);
+ return os;
+}
+#endif
+
+} // end namespace bfloat16_impl
+
+namespace internal {
+
+template<>
+struct random_default_impl<bfloat16, false, false>
+{
+ static inline bfloat16 run(const bfloat16& x, const bfloat16& y)
+ {
+ return x + (y-x) * bfloat16(float(std::rand()) / float(RAND_MAX));
+ }
+ static inline bfloat16 run()
+ {
+ return run(bfloat16(-1.f), bfloat16(1.f));
+ }
+};
+
+template<> struct is_arithmetic<bfloat16> { enum { value = true }; };
+
+} // end namespace internal
+
+template<> struct NumTraits<Eigen::bfloat16>
+ : GenericNumTraits<Eigen::bfloat16>
+{
+ enum {
+ IsSigned = true,
+ IsInteger = false,
+ IsComplex = false,
+ RequireInitialization = false
+ };
+
+ EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 epsilon() {
+ return bfloat16_impl::raw_uint16_to_bfloat16(0x3c00);
+ }
+ EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 dummy_precision() { return Eigen::bfloat16(5e-2f); }
+ EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 highest() {
+ return bfloat16_impl::raw_uint16_to_bfloat16(0x7F7F);
+ }
+ EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 lowest() {
+ return bfloat16_impl::raw_uint16_to_bfloat16(0xFF7F);
+ }
+ EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 infinity() {
+ return bfloat16_impl::raw_uint16_to_bfloat16(0x7f80);
+ }
+ EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 quiet_NaN() {
+ return bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0);
+ }
+};
+
+} // end namespace Eigen
+
+namespace std {
+
+#if __cplusplus > 199711L
+template <>
+struct hash<Eigen::bfloat16> {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t operator()(const Eigen::bfloat16& a) const {
+ return hash<float>()(static_cast<float>(a));
+ }
+};
+#endif
+
+} // end namespace std
+
+
+namespace Eigen {
+namespace numext {
+
+template<>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
+bool (isnan)(const Eigen::bfloat16& h) {
+ return (bfloat16_impl::isnan)(h);
+}
+
+template<>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
+bool (isinf)(const Eigen::bfloat16& h) {
+ return (bfloat16_impl::isinf)(h);
+}
+
+template<>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
+bool (isfinite)(const Eigen::bfloat16& h) {
+ return (bfloat16_impl::isfinite)(h);
+}
+
+} // namespace Eigen
+} // namespace numext
+
+#endif // EIGEN_BFLOAT16_H