diff options
Diffstat (limited to 'tensorflow/core/lib/bfloat16/bfloat16.h')
-rw-r--r-- | tensorflow/core/lib/bfloat16/bfloat16.h | 20 |
1 files changed, 13 insertions, 7 deletions
diff --git a/tensorflow/core/lib/bfloat16/bfloat16.h b/tensorflow/core/lib/bfloat16/bfloat16.h index 1c130ba300..d6f3f26cd5 100644 --- a/tensorflow/core/lib/bfloat16/bfloat16.h +++ b/tensorflow/core/lib/bfloat16/bfloat16.h @@ -45,17 +45,25 @@ typedef std::complex<double> complex128; struct bfloat16 { B16_DEVICE_FUNC bfloat16() {} - B16_DEVICE_FUNC explicit bfloat16(const float v) { + B16_DEVICE_FUNC static bfloat16 truncate_to_bfloat16(const float v) { + bfloat16 output; if (float_isnan(v)) { - value = NAN_VALUE; - return; + output.value = NAN_VALUE; + return output; } const uint16_t* p = reinterpret_cast<const uint16_t*>(&v); #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - value = p[0]; + output.value = p[0]; #else - value = p[1]; + output.value = p[1]; #endif + return output; + } + + B16_DEVICE_FUNC explicit bfloat16(const float v) { + // TODO(asabne) : change the below line to + // value = round_to_bfloat16(v).value; + value = truncate_to_bfloat16(v).value; } B16_DEVICE_FUNC explicit bfloat16(const double val) @@ -169,8 +177,6 @@ struct bfloat16 { // Converts a float point to bfloat16, with round-nearest-to-even as rounding // method. - // TODO(b/69266521): Add a truncate_to_bfloat16 function and make this - // function as default behavior. // TODO: There is a slightly faster implementation (8% faster on CPU) // than this (documented in cl/175987786), that is exponentially harder to // understand and document. Switch to the faster version when converting to |