aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/lib/bfloat16/bfloat16.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/lib/bfloat16/bfloat16.h')
-rw-r--r--tensorflow/core/lib/bfloat16/bfloat16.h20
1 files changed, 13 insertions, 7 deletions
diff --git a/tensorflow/core/lib/bfloat16/bfloat16.h b/tensorflow/core/lib/bfloat16/bfloat16.h
index 1c130ba300..d6f3f26cd5 100644
--- a/tensorflow/core/lib/bfloat16/bfloat16.h
+++ b/tensorflow/core/lib/bfloat16/bfloat16.h
@@ -45,17 +45,25 @@ typedef std::complex<double> complex128;
struct bfloat16 {
B16_DEVICE_FUNC bfloat16() {}
- B16_DEVICE_FUNC explicit bfloat16(const float v) {
+ B16_DEVICE_FUNC static bfloat16 truncate_to_bfloat16(const float v) {
+ bfloat16 output;
if (float_isnan(v)) {
- value = NAN_VALUE;
- return;
+ output.value = NAN_VALUE;
+ return output;
}
const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
- value = p[0];
+ output.value = p[0];
#else
- value = p[1];
+ output.value = p[1];
#endif
+ return output;
+ }
+
+ B16_DEVICE_FUNC explicit bfloat16(const float v) {
+ // TODO(asabne) : change the below line to
+ // value = round_to_bfloat16(v).value;
+ value = truncate_to_bfloat16(v).value;
}
B16_DEVICE_FUNC explicit bfloat16(const double val)
@@ -169,8 +177,6 @@ struct bfloat16 {
// Converts a float point to bfloat16, with round-nearest-to-even as rounding
// method.
- // TODO(b/69266521): Add a truncate_to_bfloat16 function and make this
- // function as default behavior.
// TODO: There is a slightly faster implementation (8% faster on CPU)
// than this (documented in cl/175987786), that is exponentially harder to
// understand and document. Switch to the faster version when converting to