aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--Eigen/src/Core/arch/Default/BFloat16.h4
-rw-r--r--test/bfloat16_float.cpp11
2 files changed, 13 insertions, 2 deletions
diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h
index 30f3cd456..3b36c2f23 100644
--- a/Eigen/src/Core/arch/Default/BFloat16.h
+++ b/Eigen/src/Core/arch/Default/BFloat16.h
@@ -256,7 +256,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, In
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const float v) {
__bfloat16_raw output;
if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(v)) {
- output.value = 0x7FC0;
+ output.value = std::signbit(v) ? 0xFFC0: 0x7FC0;
return output;
} else if (std::fabs(v) < std::numeric_limits<float>::min EIGEN_NOT_A_MACRO()) {
// Flush denormal to +/- 0.
@@ -293,7 +293,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<fals
//
// qNaN magic: All exponent bits set + most significant bit of fraction
// set.
- output.value = 0x7fc0;
+ output.value = std::signbit(ff) ? 0xFFC0: 0x7FC0;
} else if (std::fabs(ff) < std::numeric_limits<float>::min EIGEN_NOT_A_MACRO()) {
// Flush denormal to +/- 0.0
output.value = std::signbit(ff) ? 0x8000 : 0;
diff --git a/test/bfloat16_float.cpp b/test/bfloat16_float.cpp
index 94226e999..79c868e84 100644
--- a/test/bfloat16_float.cpp
+++ b/test/bfloat16_float.cpp
@@ -230,6 +230,17 @@ void test_conversion()
VERIFY((numext::isnan)(bfloat16(__bfloat16_raw(0xffc0))));
VERIFY((numext::isinf)(bfloat16(__bfloat16_raw(0x7f80))));
VERIFY((numext::isnan)(bfloat16(__bfloat16_raw(0x7fc0))));
+
+ VERIFY_IS_EQUAL(bfloat16(BinaryToFloat(0x0, 0xff, 0x40, 0x0)).value, 0x7fc0);
+ VERIFY_IS_EQUAL(bfloat16(BinaryToFloat(0x1, 0xff, 0x40, 0x0)).value, 0xffc0);
+ VERIFY_IS_EQUAL(Eigen::bfloat16_impl::truncate_to_bfloat16(
+ BinaryToFloat(0x0, 0xff, 0x40, 0x0))
+ .value,
+ 0x7fc0);
+ VERIFY_IS_EQUAL(Eigen::bfloat16_impl::truncate_to_bfloat16(
+ BinaryToFloat(0x1, 0xff, 0x40, 0x0))
+ .value,
+ 0xffc0);
}
void test_numtraits()