diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-25 08:23:57 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-25 08:27:36 -0700 |
commit | b3771feab49e2122164737a860341727d08c2d8c (patch) | |
tree | 5fb440041db26ef96eb14e7491cb67fe06e7c3d4 /tensorflow/core/lib | |
parent | be3d22844025e42e177a21479f3ae73bc5351c1f (diff) |
This change started with an intention of adding an attribute to cast ops to decide
whether bfloat16 casts should use truncation or rounding.
This is a preparatory change before we switch the default float ==> bfloat16 cast
to use rounding instead of truncation. The attribute added can then be specified
on casts that rely on the truncation, e.g., the TensorFlow send/receive operations.
It later emerged that the choice of doing truncation is useful more generally.
Therefore, this change allows the new attribute to be used by all relevant casts
to use truncation instead of rounding.
PiperOrigin-RevId: 205996367
Diffstat (limited to 'tensorflow/core/lib')
-rw-r--r-- | tensorflow/core/lib/bfloat16/bfloat16.h | 20 |
1 files changed, 13 insertions, 7 deletions
diff --git a/tensorflow/core/lib/bfloat16/bfloat16.h b/tensorflow/core/lib/bfloat16/bfloat16.h index 1c130ba300..d6f3f26cd5 100644 --- a/tensorflow/core/lib/bfloat16/bfloat16.h +++ b/tensorflow/core/lib/bfloat16/bfloat16.h @@ -45,17 +45,25 @@ typedef std::complex<double> complex128; struct bfloat16 { B16_DEVICE_FUNC bfloat16() {} - B16_DEVICE_FUNC explicit bfloat16(const float v) { + B16_DEVICE_FUNC static bfloat16 truncate_to_bfloat16(const float v) { + bfloat16 output; if (float_isnan(v)) { - value = NAN_VALUE; - return; + output.value = NAN_VALUE; + return output; } const uint16_t* p = reinterpret_cast<const uint16_t*>(&v); #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - value = p[0]; + output.value = p[0]; #else - value = p[1]; + output.value = p[1]; #endif + return output; + } + + B16_DEVICE_FUNC explicit bfloat16(const float v) { + // TODO(asabne) : change the below line to + // value = round_to_bfloat16(v).value; + value = truncate_to_bfloat16(v).value; } B16_DEVICE_FUNC explicit bfloat16(const double val) @@ -169,8 +177,6 @@ struct bfloat16 { // Converts a float point to bfloat16, with round-nearest-to-even as rounding // method. - // TODO(b/69266521): Add a truncate_to_bfloat16 function and make this - // function as default behavior. // TODO: There is a slightly faster implementation (8% faster on CPU) // than this (documented in cl/175987786), that is exponentially harder to // understand and document. Switch to the faster version when converting to |