diff options
author | 2017-11-09 20:45:39 -0800 | |
---|---|---|
committer | 2017-11-10 16:14:41 -0800 | |
commit | 64d9aa1ace99c66f20b65532f633acb34ee3c057 (patch) | |
tree | 42589268a62815c66093ad7185eb507b1562f9fb /tensorflow/core/framework/numeric_types.h | |
parent | 685f604f63a30a8162d8762e9d8d22f171dca85e (diff) |
Add bfloat support to XLA.
This is necessary for providing bfloat16 support in the GPU backend.
RELNOTES: bfloat support is now added to XLA infra.
PiperOrigin-RevId: 175252067
Diffstat (limited to 'tensorflow/core/framework/numeric_types.h')
-rw-r--r-- | tensorflow/core/framework/numeric_types.h | 251 |
1 file changed, 242 insertions, 9 deletions
diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h index a630bee38d..d005de2af1 100644 --- a/tensorflow/core/framework/numeric_types.h +++ b/tensorflow/core/framework/numeric_types.h @@ -44,29 +44,262 @@ typedef Eigen::QUInt16 quint16; // see framework/bfloat16.h for description. struct bfloat16 { EIGEN_DEVICE_FUNC bfloat16() {} - EIGEN_DEVICE_FUNC explicit bfloat16(const float v) { - const uint16_t* p = reinterpret_cast<const uint16_t*>(&v); + + explicit EIGEN_DEVICE_FUNC bfloat16(float v) { + uint32_t input; + memcpy(&input, &v, sizeof(uint32_t)); + + if ((~input & 0x7f800000) == 0 && (input & 0x007fffff) != 0) { + // If the value is a NaN, squash it to a qNaN with msb of fraction set, + // this makes sure after truncation we don't end up with an inf. + // + // qNaN magic: All exponent bits set + most significant bit of fraction + // set. + value = 0x7fc0; + } else { + // Fast rounding algorithm that rounds a half value to nearest even. This + // reduces expected error when we convert a large number of floats. Here + // is how it works: + // + // Definitions: + // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits + // with the following tags: + // + // Sign | Exp (8 bits) | Frac (23 bits) + // S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT + // + // S: Sign bit. + // E: Exponent bits. + // F: First 6 bits of fraction. + // L: Least significant bit of resulting bfloat16 if we truncate away the + // rest of the float32. This is also the 7th bit of fraction + // R: Rounding bit, 8th bit of fraction. + // T: Sticky bits, rest of fraction, 15 bits. + // + // To round half to nearest even, there are 3 cases where we want to round + // down (simply truncate the result of the bits away, which consists of + // rounding bit and sticky bits) and two cases where we want to round up + // (truncate then add one to the result). 
+ // + // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of + // 1s) as the rounding bias, adds the rounding bias to the input, then + // truncates the last 16 bits away. + // + // To understand how it works, we can analyze this algorithm case by case: + // + // 1. L = 0, R = 0: + // Expect: round down, this is less than half value. + // + // Algorithm: + // - Rounding bias: 0x7fff + 0 = 0x7fff + // - Adding rounding bias to input may create any carry, depending on + // whether there is any value set to 1 in T bits. + // - R may be set to 1 if there is a carry. + // - L remains 0. + // - Note that this case also handles Inf and -Inf, where all fraction + // bits, including L, R and Ts are all 0. The output remains Inf after + // this algorithm. + // + // 2. L = 1, R = 0: + // Expect: round down, this is less than half value. + // + // Algorithm: + // - Rounding bias: 0x7fff + 1 = 0x8000 + // - Adding rounding bias to input doesn't change sticky bits but + // adds 1 to rounding bit. + // - L remains 1. + // + // 3. L = 0, R = 1, all of T are 0: + // Expect: round down, this is exactly at half, the result is already + // even (L=0). + // + // Algorithm: + // - Rounding bias: 0x7fff + 0 = 0x7fff + // - Adding rounding bias to input sets all sticky bits to 1, but + // doesn't create a carry. + // - R remains 1. + // - L remains 0. + // + // 4. L = 1, R = 1: + // Expect: round up, this is exactly at half, the result needs to be + // round to the next even number. + // + // Algorithm: + // - Rounding bias: 0x7fff + 1 = 0x8000 + // - Adding rounding bias to input doesn't change sticky bits, but + // creates a carry from rounding bit. + // - The carry sets L to 0, creates another carry bit and propagate + // forward to F bits. + // - If all the F bits are 1, a carry then propagates to the exponent + // bits, which then creates the minimum value with the next exponent + // value. 
Note that we won't have the case where exponents are all 1, + // since that's either a NaN (handled in the other if condition) or inf + // (handled in case 1). + // + // 5. L = 0, R = 1, any of T is 1: + // Expect: round up, this is greater than half. + // + // Algorithm: + // - Rounding bias: 0x7fff + 0 = 0x7fff + // - Adding rounding bias to input creates a carry from sticky bits, + // sets rounding bit to 0, then create another carry. + // - The second carry sets L to 1. + // + // Examples: + // + // Exact half value that is already even: + // Input: + // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) + // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT + // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000 + // + // This falls into case 3. We truncate the rest of 16 bits and no + // carry is created into F and L: + // + // Output: + // Sign | Exp (8 bit) | Frac (first 7 bit) + // S E E E E E E E E F F F F F F L + // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 + // + // Exact half value, round to next even number: + // Input: + // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) + // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT + // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000 + // + // This falls into case 4. We create a carry from R and T, + // which then propagates into L and F: + // + // Output: + // Sign | Exp (8 bit) | Frac (first 7 bit) + // S E E E E E E E E F F F F F F L + // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 + // + // + // Max denormal value round to min normal value: + // Input: + // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) + // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT + // 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111 + // + // This falls into case 4. 
We create a carry from R and T, + // propagate into L and F, which then propagates into exponent + // bits: + // + // Output: + // Sign | Exp (8 bit) | Frac (first 7 bit) + // S E E E E E E E E F F F F F F L + // 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 + // + // Max normal value round to Inf: + // Input: + // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) + // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT + // 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111 + // + // This falls into case 4. We create a carry from R and T, + // propagate into L and F, which then propagates into exponent + // bits: + // + // Sign | Exp (8 bit) | Frac (first 7 bit) + // S E E E E E E E E F F F F F F L + // 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 + // + // + // Least significant bit of resulting bfloat. + uint32_t lsb = (input >> 16) & 1; + uint32_t rounding_bias = 0x7fff + lsb; + input += rounding_bias; + value = static_cast<uint16_t>(input >> 16); + } + } + + template <class T> + explicit EIGEN_DEVICE_FUNC bfloat16(const T& val) + : bfloat16(static_cast<float>(val)) {} + + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const { + float result; + + uint16_t* q = reinterpret_cast<uint16_t*>(&result); + #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - value = p[0]; + q[0] = value; + q[1] = 0; #else - value = p[1]; + q[0] = 0; + q[1] = value; #endif + return result; + } + + EIGEN_DEVICE_FUNC explicit operator bool() const { + return static_cast<bool>(float(*this)); + } + + EIGEN_DEVICE_FUNC explicit operator Eigen::half() const { + return static_cast<Eigen::half>(float(*this)); + } + + EIGEN_DEVICE_FUNC explicit operator short() const { + return static_cast<short>(float(*this)); + } + + EIGEN_DEVICE_FUNC explicit operator int() const { + return static_cast<int>(float(*this)); + } + + EIGEN_DEVICE_FUNC explicit operator char() const { + return static_cast<char>(float(*this)); + } + + EIGEN_DEVICE_FUNC explicit operator signed char() const { + return static_cast<signed char>(float(*this)); + } 
+ + EIGEN_DEVICE_FUNC explicit operator unsigned char() const { + return static_cast<unsigned char>(float(*this)); + } + + EIGEN_DEVICE_FUNC explicit operator unsigned int() const { + return static_cast<unsigned int>(float(*this)); + } + + EIGEN_DEVICE_FUNC explicit operator unsigned long() const { + return static_cast<unsigned long>(float(*this)); + } + + EIGEN_DEVICE_FUNC explicit operator unsigned long long() const { + return static_cast<unsigned long long>(float(*this)); + } + + EIGEN_DEVICE_FUNC explicit operator long long() const { + return static_cast<long long>(float(*this)); + } + + EIGEN_DEVICE_FUNC explicit operator double() const { + return static_cast<double>(float(*this)); } uint16_t value; }; +inline bool operator==(const bfloat16 a, const bfloat16 b) { + return a.value == b.value; +} + +inline bool operator!=(const bfloat16 a, const bfloat16 b) { + return a.value != b.value; +} + } // end namespace tensorflow namespace Eigen { template <> struct NumTraits<tensorflow::bfloat16> : GenericNumTraits<uint16_t> {}; -EIGEN_STRONG_INLINE bool operator==(const tensorflow::bfloat16 a, - const tensorflow::bfloat16 b) { - return a.value == b.value; -} - +using ::tensorflow::operator==; +using ::tensorflow::operator!=; } // namespace Eigen #ifdef COMPILER_MSVC |