about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/framework/numeric_types.h
diff options
context:
space:
mode:
author: Yunxing Dai <yunxing@google.com> 2017-11-09 20:45:39 -0800
committer: Andrew Selle <aselle@andyselle.com> 2017-11-10 16:14:41 -0800
commit: 64d9aa1ace99c66f20b65532f633acb34ee3c057 (patch)
tree: 42589268a62815c66093ad7185eb507b1562f9fb /tensorflow/core/framework/numeric_types.h
parent: 685f604f63a30a8162d8762e9d8d22f171dca85e (diff)
Add bfloat support to XLA.
This is necessary in providing bfloat support in GPU backend. RELNOTES: bfloat support is now added to XLA infra. PiperOrigin-RevId: 175252067
Diffstat (limited to 'tensorflow/core/framework/numeric_types.h')
-rw-r--r-- tensorflow/core/framework/numeric_types.h 251
1 files changed, 242 insertions, 9 deletions
diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h
index a630bee38d..d005de2af1 100644
--- a/tensorflow/core/framework/numeric_types.h
+++ b/tensorflow/core/framework/numeric_types.h
@@ -44,29 +44,262 @@ typedef Eigen::QUInt16 quint16;
// see framework/bfloat16.h for description.
struct bfloat16 {
EIGEN_DEVICE_FUNC bfloat16() {}
- EIGEN_DEVICE_FUNC explicit bfloat16(const float v) {
- const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
+
+ explicit EIGEN_DEVICE_FUNC bfloat16(float v) {
+ uint32_t input;
+ memcpy(&input, &v, sizeof(uint32_t));
+
+ if ((~input & 0x7f800000) == 0 && (input & 0x007fffff) != 0) {
+ // If the value is a NaN, squash it to a qNaN with the msb of the fraction
+ // set. This makes sure that after truncation we don't end up with an inf.
+ //
+ // qNaN magic: All exponent bits set + most significant bit of fraction
+ // set.
+ value = 0x7fc0;
+ } else {
+ // Fast rounding algorithm that rounds a half value to nearest even. This
+ // reduces expected error when we convert a large number of floats. Here
+ // is how it works:
+ //
+ // Definitions:
+ // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
+ // with the following tags:
+ //
+ // Sign | Exp (8 bits) | Frac (23 bits)
+ // S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT
+ //
+ // S: Sign bit.
+ // E: Exponent bits.
+ // F: First 6 bits of fraction.
+ // L: Least significant bit of resulting bfloat16 if we truncate away the
+ // rest of the float32. This is also the 7th bit of fraction.
+ // R: Rounding bit, 8th bit of fraction.
+ // T: Sticky bits, rest of fraction, 15 bits.
+ //
+ // To round half to nearest even, there are 3 cases where we want to round
+ // down (simply truncate away the low 16 bits, which consist of the
+ // rounding bit and sticky bits) and 2 cases where we want to round up
+ // (truncate, then add one to the result).
+ //
+ // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
+ // 1s) as the rounding bias, adds the rounding bias to the input, then
+ // truncates the last 16 bits away.
+ //
+ // To understand how it works, we can analyze this algorithm case by case:
+ //
+ // 1. L = 0, R = 0:
+ // Expect: round down, this is less than half value.
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 0 = 0x7fff
+ // - Adding rounding bias to input creates a carry out of the T bits
+ // whenever any T bit is set to 1, but never a carry out of R.
+ // - R may be set to 1 if there is a carry.
+ // - L remains 0.
+ // - Note that this case also handles Inf and -Inf, where all fraction
+ // bits, including L, R and Ts are all 0. The output remains Inf after
+ // this algorithm.
+ //
+ // 2. L = 1, R = 0:
+ // Expect: round down, this is less than half value.
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 1 = 0x8000
+ // - Adding rounding bias to input doesn't change sticky bits but
+ // adds 1 to rounding bit.
+ // - L remains 1.
+ //
+ // 3. L = 0, R = 1, all of T are 0:
+ // Expect: round down, this is exactly at half, the result is already
+ // even (L=0).
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 0 = 0x7fff
+ // - Adding rounding bias to input sets all sticky bits to 1, but
+ // doesn't create a carry.
+ // - R remains 1.
+ // - L remains 0.
+ //
+ // 4. L = 1, R = 1:
+ // Expect: round up, this is at least half (exactly half when all of T
+ // are 0) and L is odd, so the result rounds up to the next even number.
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 1 = 0x8000
+ // - Adding rounding bias to input doesn't change sticky bits, but
+ // creates a carry from rounding bit.
+ // - The carry sets L to 0, creates another carry bit and propagate
+ // forward to F bits.
+ // - If all the F bits are 1, a carry then propagates to the exponent
+ // bits, which then creates the minimum value with the next exponent
+ // value. Note that we won't have the case where exponents are all 1,
+ // since that's either a NaN (handled in the other if condition) or inf
+ // (handled in case 1).
+ //
+ // 5. L = 0, R = 1, any of T is 1:
+ // Expect: round up, this is greater than half.
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 0 = 0x7fff
+ // - Adding rounding bias to input creates a carry from sticky bits,
+ // sets rounding bit to 0, then creates another carry.
+ // - The second carry sets L to 1.
+ //
+ // Examples:
+ //
+ // Exact half value that is already even:
+ // Input:
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000
+ //
+ // This falls into case 3. We truncate the rest of 16 bits and no
+ // carry is created into F and L:
+ //
+ // Output:
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
+ // S E E E E E E E E F F F F F F L
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
+ //
+ // Exact half value, round to next even number:
+ // Input:
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000
+ //
+ // This falls into case 4. We create a carry from R (all of T are 0),
+ // which then propagates into L and F:
+ //
+ // Output:
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
+ // S E E E E E E E E F F F F F F L
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
+ //
+ //
+ // Max denormal value round to min normal value:
+ // Input:
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
+ // 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111
+ //
+ // This falls into case 4. We create a carry from R and T,
+ // propagate into L and F, which then propagates into exponent
+ // bits:
+ //
+ // Output:
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
+ // S E E E E E E E E F F F F F F L
+ // 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
+ //
+ // Max normal value round to Inf:
+ // Input:
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
+ // 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111
+ //
+ // This falls into case 4. We create a carry from R and T,
+ // propagate into L and F, which then propagates into exponent
+ // bits:
+ //
+ // Output:
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
+ // S E E E E E E E E F F F F F F L
+ // 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
+ //
+ //
+ // Least significant bit of resulting bfloat.
+ uint32_t lsb = (input >> 16) & 1;
+ uint32_t rounding_bias = 0x7fff + lsb;
+ input += rounding_bias;
+ value = static_cast<uint16_t>(input >> 16);
+ }
+ }
+
+ template <class T>
+ explicit EIGEN_DEVICE_FUNC bfloat16(const T& val)
+ : bfloat16(static_cast<float>(val)) {}
+
+ EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const {
+ float result;
+
+ uint16_t* q = reinterpret_cast<uint16_t*>(&result);
+
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
- value = p[0];
+ q[0] = value;
+ q[1] = 0;
#else
- value = p[1];
+ q[0] = 0;
+ q[1] = value;
#endif
+ return result;
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator bool() const {
+ return static_cast<bool>(float(*this));
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator Eigen::half() const {
+ return static_cast<Eigen::half>(float(*this));
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator short() const {
+ return static_cast<short>(float(*this));
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator int() const {
+ return static_cast<int>(float(*this));
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator char() const {
+ return static_cast<char>(float(*this));
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator signed char() const {
+ return static_cast<signed char>(float(*this));
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator unsigned char() const {
+ return static_cast<unsigned char>(float(*this));
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator unsigned int() const {
+ return static_cast<unsigned int>(float(*this));
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator unsigned long() const {
+ return static_cast<unsigned long>(float(*this));
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator unsigned long long() const {
+ return static_cast<unsigned long long>(float(*this));
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator long long() const {
+ return static_cast<long long>(float(*this));
+ }
+
+ EIGEN_DEVICE_FUNC explicit operator double() const {
+ return static_cast<double>(float(*this));
}
uint16_t value;
};
+inline bool operator==(const bfloat16 a, const bfloat16 b) {
+ return a.value == b.value;
+}
+
+inline bool operator!=(const bfloat16 a, const bfloat16 b) {
+ return a.value != b.value;
+}
+
} // end namespace tensorflow
namespace Eigen {
template <>
struct NumTraits<tensorflow::bfloat16> : GenericNumTraits<uint16_t> {};
-EIGEN_STRONG_INLINE bool operator==(const tensorflow::bfloat16 a,
- const tensorflow::bfloat16 b) {
- return a.value == b.value;
-}
-
+using ::tensorflow::operator==;
+using ::tensorflow::operator!=;
} // namespace Eigen
#ifdef COMPILER_MSVC