aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/numeric_types.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/numeric_types.h')
-rw-r--r--tensorflow/core/framework/numeric_types.h251
1 files changed, 9 insertions, 242 deletions
diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h
index d005de2af1..a630bee38d 100644
--- a/tensorflow/core/framework/numeric_types.h
+++ b/tensorflow/core/framework/numeric_types.h
@@ -44,262 +44,29 @@ typedef Eigen::QUInt16 quint16;
// see framework/bfloat16.h for description.
struct bfloat16 {
EIGEN_DEVICE_FUNC bfloat16() {}
-
- explicit EIGEN_DEVICE_FUNC bfloat16(float v) {
- uint32_t input;
- memcpy(&input, &v, sizeof(uint32_t));
-
- if ((~input & 0x7f800000) == 0 && (input & 0x007fffff) != 0) {
- // If the value is a NaN, squash it to a qNaN with msb of fraction set,
- // this makes sure after truncation we don't end up with an inf.
- //
- // qNaN magic: All exponent bits set + most significant bit of fraction
- // set.
- value = 0x7fc0;
- } else {
- // Fast rounding algorithm that rounds a half value to nearest even. This
- // reduces expected error when we convert a large number of floats. Here
- // is how it works:
- //
- // Definitions:
- // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
- // with the following tags:
- //
- // Sign | Exp (8 bits) | Frac (23 bits)
- // S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT
- //
- // S: Sign bit.
- // E: Exponent bits.
- // F: First 6 bits of fraction.
- // L: Least significant bit of resulting bfloat16 if we truncate away the
- // rest of the float32. This is also the 7th bit of fraction
- // R: Rounding bit, 8th bit of fraction.
- // T: Sticky bits, rest of fraction, 15 bits.
- //
- // To round half to nearest even, there are 3 cases where we want to round
- // down (simply truncate the result of the bits away, which consists of
- // rounding bit and sticky bits) and two cases where we want to round up
- // (truncate then add one to the result).
- //
- // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
- // 1s) as the rounding bias, adds the rounding bias to the input, then
- // truncates the last 16 bits away.
- //
- // To understand how it works, we can analyze this algorithm case by case:
- //
- // 1. L = 0, R = 0:
- // Expect: round down, this is less than half value.
- //
- // Algorithm:
- // - Rounding bias: 0x7fff + 0 = 0x7fff
- // - Adding rounding bias to input may create any carry, depending on
- // whether there is any value set to 1 in T bits.
- // - R may be set to 1 if there is a carry.
- // - L remains 0.
- // - Note that this case also handles Inf and -Inf, where all fraction
- // bits, including L, R and Ts are all 0. The output remains Inf after
- // this algorithm.
- //
- // 2. L = 1, R = 0:
- // Expect: round down, this is less than half value.
- //
- // Algorithm:
- // - Rounding bias: 0x7fff + 1 = 0x8000
- // - Adding rounding bias to input doesn't change sticky bits but
- // adds 1 to rounding bit.
- // - L remains 1.
- //
- // 3. L = 0, R = 1, all of T are 0:
- // Expect: round down, this is exactly at half, the result is already
- // even (L=0).
- //
- // Algorithm:
- // - Rounding bias: 0x7fff + 0 = 0x7fff
- // - Adding rounding bias to input sets all sticky bits to 1, but
- // doesn't create a carry.
- // - R remains 1.
- // - L remains 0.
- //
- // 4. L = 1, R = 1:
- // Expect: round up, this is exactly at half, the result needs to be
- // round to the next even number.
- //
- // Algorithm:
- // - Rounding bias: 0x7fff + 1 = 0x8000
- // - Adding rounding bias to input doesn't change sticky bits, but
- // creates a carry from rounding bit.
- // - The carry sets L to 0, creates another carry bit and propagate
- // forward to F bits.
- // - If all the F bits are 1, a carry then propagates to the exponent
- // bits, which then creates the minimum value with the next exponent
- // value. Note that we won't have the case where exponents are all 1,
- // since that's either a NaN (handled in the other if condition) or inf
- // (handled in case 1).
- //
- // 5. L = 0, R = 1, any of T is 1:
- // Expect: round up, this is greater than half.
- //
- // Algorithm:
- // - Rounding bias: 0x7fff + 0 = 0x7fff
- // - Adding rounding bias to input creates a carry from sticky bits,
- // sets rounding bit to 0, then create another carry.
- // - The second carry sets L to 1.
- //
- // Examples:
- //
- // Exact half value that is already even:
- // Input:
- // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
- // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
- // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000
- //
- // This falls into case 3. We truncate the rest of 16 bits and no
- // carry is created into F and L:
- //
- // Output:
- // Sign | Exp (8 bit) | Frac (first 7 bit)
- // S E E E E E E E E F F F F F F L
- // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
- //
- // Exact half value, round to next even number:
- // Input:
- // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
- // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
- // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000
- //
- // This falls into case 4. We create a carry from R and T,
- // which then propagates into L and F:
- //
- // Output:
- // Sign | Exp (8 bit) | Frac (first 7 bit)
- // S E E E E E E E E F F F F F F L
- // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
- //
- //
- // Max denormal value round to min normal value:
- // Input:
- // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
- // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
- // 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111
- //
- // This falls into case 4. We create a carry from R and T,
- // propagate into L and F, which then propagates into exponent
- // bits:
- //
- // Output:
- // Sign | Exp (8 bit) | Frac (first 7 bit)
- // S E E E E E E E E F F F F F F L
- // 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
- //
- // Max normal value round to Inf:
- // Input:
- // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
- // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
- // 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111
- //
- // This falls into case 4. We create a carry from R and T,
- // propagate into L and F, which then propagates into exponent
- // bits:
- //
- // Sign | Exp (8 bit) | Frac (first 7 bit)
- // S E E E E E E E E F F F F F F L
- // 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
- //
- //
- // Least significant bit of resulting bfloat.
- uint32_t lsb = (input >> 16) & 1;
- uint32_t rounding_bias = 0x7fff + lsb;
- input += rounding_bias;
- value = static_cast<uint16_t>(input >> 16);
- }
- }
-
- template <class T>
- explicit EIGEN_DEVICE_FUNC bfloat16(const T& val)
- : bfloat16(static_cast<float>(val)) {}
-
- EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const {
- float result;
-
- uint16_t* q = reinterpret_cast<uint16_t*>(&result);
-
+ EIGEN_DEVICE_FUNC explicit bfloat16(const float v) {
+ const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
- q[0] = value;
- q[1] = 0;
+ value = p[0];
#else
- q[0] = 0;
- q[1] = value;
+ value = p[1];
#endif
- return result;
- }
-
- EIGEN_DEVICE_FUNC explicit operator bool() const {
- return static_cast<bool>(float(*this));
- }
-
- EIGEN_DEVICE_FUNC explicit operator Eigen::half() const {
- return static_cast<Eigen::half>(float(*this));
- }
-
- EIGEN_DEVICE_FUNC explicit operator short() const {
- return static_cast<short>(float(*this));
- }
-
- EIGEN_DEVICE_FUNC explicit operator int() const {
- return static_cast<int>(float(*this));
- }
-
- EIGEN_DEVICE_FUNC explicit operator char() const {
- return static_cast<char>(float(*this));
- }
-
- EIGEN_DEVICE_FUNC explicit operator signed char() const {
- return static_cast<signed char>(float(*this));
- }
-
- EIGEN_DEVICE_FUNC explicit operator unsigned char() const {
- return static_cast<unsigned char>(float(*this));
- }
-
- EIGEN_DEVICE_FUNC explicit operator unsigned int() const {
- return static_cast<unsigned int>(float(*this));
- }
-
- EIGEN_DEVICE_FUNC explicit operator unsigned long() const {
- return static_cast<unsigned long>(float(*this));
- }
-
- EIGEN_DEVICE_FUNC explicit operator unsigned long long() const {
- return static_cast<unsigned long long>(float(*this));
- }
-
- EIGEN_DEVICE_FUNC explicit operator long long() const {
- return static_cast<long long>(float(*this));
- }
-
- EIGEN_DEVICE_FUNC explicit operator double() const {
- return static_cast<double>(float(*this));
}
uint16_t value;
};
-inline bool operator==(const bfloat16 a, const bfloat16 b) {
- return a.value == b.value;
-}
-
-inline bool operator!=(const bfloat16 a, const bfloat16 b) {
- return a.value != b.value;
-}
-
} // end namespace tensorflow
namespace Eigen {
template <>
struct NumTraits<tensorflow::bfloat16> : GenericNumTraits<uint16_t> {};
-using ::tensorflow::operator==;
-using ::tensorflow::operator!=;
+EIGEN_STRONG_INLINE bool operator==(const tensorflow::bfloat16 a,
+ const tensorflow::bfloat16 b) {
+ return a.value == b.value;
+}
+
} // namespace Eigen
#ifdef COMPILER_MSVC